gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

client.go (12681B)


      1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package websocket
      6 
      7 import (
      8 	"bytes"
      9 	"context"
     10 	"crypto/tls"
     11 	"errors"
     12 	"io"
     13 	"io/ioutil"
     14 	"net"
     15 	"net/http"
     16 	"net/http/httptrace"
     17 	"net/url"
     18 	"strings"
     19 	"time"
     20 )
     21 
     22 // ErrBadHandshake is returned when the server response to opening handshake is
     23 // invalid.
     24 var ErrBadHandshake = errors.New("websocket: bad handshake")
     25 
     26 var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
     27 
     28 // NewClient creates a new client connection using the given net connection.
     29 // The URL u specifies the host and request URI. Use requestHeader to specify
     30 // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
     31 // (Cookie). Use the response.Header to get the selected subprotocol
     32 // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
     33 //
     34 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
     35 // non-nil *http.Response so that callers can handle redirects, authentication,
     36 // etc.
     37 //
     38 // Deprecated: Use Dialer instead.
     39 func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
     40 	d := Dialer{
     41 		ReadBufferSize:  readBufSize,
     42 		WriteBufferSize: writeBufSize,
     43 		NetDial: func(net, addr string) (net.Conn, error) {
     44 			return netConn, nil
     45 		},
     46 	}
     47 	return d.Dial(u.String(), requestHeader)
     48 }
     49 
     50 // A Dialer contains options for connecting to WebSocket server.
     51 //
     52 // It is safe to call Dialer's methods concurrently.
     53 type Dialer struct {
     54 	// NetDial specifies the dial function for creating TCP connections. If
     55 	// NetDial is nil, net.Dial is used.
     56 	NetDial func(network, addr string) (net.Conn, error)
     57 
     58 	// NetDialContext specifies the dial function for creating TCP connections. If
     59 	// NetDialContext is nil, NetDial is used.
     60 	NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
     61 
     62 	// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
     63 	// NetDialTLSContext is nil, NetDialContext is used.
     64 	// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
     65 	// TLSClientConfig is ignored.
     66 	NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
     67 
     68 	// Proxy specifies a function to return a proxy for a given
     69 	// Request. If the function returns a non-nil error, the
     70 	// request is aborted with the provided error.
     71 	// If Proxy is nil or returns a nil *URL, no proxy is used.
     72 	Proxy func(*http.Request) (*url.URL, error)
     73 
     74 	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
     75 	// If nil, the default configuration is used.
     76 	// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
     77 	// is done there and TLSClientConfig is ignored.
     78 	TLSClientConfig *tls.Config
     79 
     80 	// HandshakeTimeout specifies the duration for the handshake to complete.
     81 	HandshakeTimeout time.Duration
     82 
     83 	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
     84 	// size is zero, then a useful default size is used. The I/O buffer sizes
     85 	// do not limit the size of the messages that can be sent or received.
     86 	ReadBufferSize, WriteBufferSize int
     87 
     88 	// WriteBufferPool is a pool of buffers for write operations. If the value
     89 	// is not set, then write buffers are allocated to the connection for the
     90 	// lifetime of the connection.
     91 	//
     92 	// A pool is most useful when the application has a modest volume of writes
     93 	// across a large number of connections.
     94 	//
     95 	// Applications should use a single pool for each unique value of
     96 	// WriteBufferSize.
     97 	WriteBufferPool BufferPool
     98 
     99 	// Subprotocols specifies the client's requested subprotocols.
    100 	Subprotocols []string
    101 
    102 	// EnableCompression specifies if the client should attempt to negotiate
    103 	// per message compression (RFC 7692). Setting this value to true does not
    104 	// guarantee that compression will be supported. Currently only "no context
    105 	// takeover" modes are supported.
    106 	EnableCompression bool
    107 
    108 	// Jar specifies the cookie jar.
    109 	// If Jar is nil, cookies are not sent in requests and ignored
    110 	// in responses.
    111 	Jar http.CookieJar
    112 }
    113 
    114 // Dial creates a new client connection by calling DialContext with a background context.
    115 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
    116 	return d.DialContext(context.Background(), urlStr, requestHeader)
    117 }
    118 
    119 var errMalformedURL = errors.New("malformed ws or wss URL")
    120 
    121 func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
    122 	hostPort = u.Host
    123 	hostNoPort = u.Host
    124 	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
    125 		hostNoPort = hostNoPort[:i]
    126 	} else {
    127 		switch u.Scheme {
    128 		case "wss":
    129 			hostPort += ":443"
    130 		case "https":
    131 			hostPort += ":443"
    132 		default:
    133 			hostPort += ":80"
    134 		}
    135 	}
    136 	return hostPort, hostNoPort
    137 }
    138 
    139 // DefaultDialer is a dialer with all fields set to the default values.
    140 var DefaultDialer = &Dialer{
    141 	Proxy:            http.ProxyFromEnvironment,
    142 	HandshakeTimeout: 45 * time.Second,
    143 }
    144 
    145 // nilDialer is dialer to use when receiver is nil.
    146 var nilDialer = *DefaultDialer
    147 
    148 // DialContext creates a new client connection. Use requestHeader to specify the
    149 // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
    150 // Use the response.Header to get the selected subprotocol
    151 // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
    152 //
    153 // The context will be used in the request and in the Dialer.
    154 //
    155 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
    156 // non-nil *http.Response so that callers can handle redirects, authentication,
    157 // etcetera. The response body may not contain the entire response and does not
    158 // need to be closed by the application.
    159 func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
    160 	if d == nil {
    161 		d = &nilDialer
    162 	}
    163 
    164 	challengeKey, err := generateChallengeKey()
    165 	if err != nil {
    166 		return nil, nil, err
    167 	}
    168 
    169 	u, err := url.Parse(urlStr)
    170 	if err != nil {
    171 		return nil, nil, err
    172 	}
    173 
    174 	switch u.Scheme {
    175 	case "ws":
    176 		u.Scheme = "http"
    177 	case "wss":
    178 		u.Scheme = "https"
    179 	default:
    180 		return nil, nil, errMalformedURL
    181 	}
    182 
    183 	if u.User != nil {
    184 		// User name and password are not allowed in websocket URIs.
    185 		return nil, nil, errMalformedURL
    186 	}
    187 
    188 	req := &http.Request{
    189 		Method:     http.MethodGet,
    190 		URL:        u,
    191 		Proto:      "HTTP/1.1",
    192 		ProtoMajor: 1,
    193 		ProtoMinor: 1,
    194 		Header:     make(http.Header),
    195 		Host:       u.Host,
    196 	}
    197 	req = req.WithContext(ctx)
    198 
    199 	// Set the cookies present in the cookie jar of the dialer
    200 	if d.Jar != nil {
    201 		for _, cookie := range d.Jar.Cookies(u) {
    202 			req.AddCookie(cookie)
    203 		}
    204 	}
    205 
    206 	// Set the request headers using the capitalization for names and values in
    207 	// RFC examples. Although the capitalization shouldn't matter, there are
    208 	// servers that depend on it. The Header.Set method is not used because the
    209 	// method canonicalizes the header names.
    210 	req.Header["Upgrade"] = []string{"websocket"}
    211 	req.Header["Connection"] = []string{"Upgrade"}
    212 	req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
    213 	req.Header["Sec-WebSocket-Version"] = []string{"13"}
    214 	if len(d.Subprotocols) > 0 {
    215 		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
    216 	}
    217 	for k, vs := range requestHeader {
    218 		switch {
    219 		case k == "Host":
    220 			if len(vs) > 0 {
    221 				req.Host = vs[0]
    222 			}
    223 		case k == "Upgrade" ||
    224 			k == "Connection" ||
    225 			k == "Sec-Websocket-Key" ||
    226 			k == "Sec-Websocket-Version" ||
    227 			k == "Sec-Websocket-Extensions" ||
    228 			(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
    229 			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
    230 		case k == "Sec-Websocket-Protocol":
    231 			req.Header["Sec-WebSocket-Protocol"] = vs
    232 		default:
    233 			req.Header[k] = vs
    234 		}
    235 	}
    236 
    237 	if d.EnableCompression {
    238 		req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
    239 	}
    240 
    241 	if d.HandshakeTimeout != 0 {
    242 		var cancel func()
    243 		ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
    244 		defer cancel()
    245 	}
    246 
    247 	// Get network dial function.
    248 	var netDial func(network, add string) (net.Conn, error)
    249 
    250 	switch u.Scheme {
    251 	case "http":
    252 		if d.NetDialContext != nil {
    253 			netDial = func(network, addr string) (net.Conn, error) {
    254 				return d.NetDialContext(ctx, network, addr)
    255 			}
    256 		} else if d.NetDial != nil {
    257 			netDial = d.NetDial
    258 		}
    259 	case "https":
    260 		if d.NetDialTLSContext != nil {
    261 			netDial = func(network, addr string) (net.Conn, error) {
    262 				return d.NetDialTLSContext(ctx, network, addr)
    263 			}
    264 		} else if d.NetDialContext != nil {
    265 			netDial = func(network, addr string) (net.Conn, error) {
    266 				return d.NetDialContext(ctx, network, addr)
    267 			}
    268 		} else if d.NetDial != nil {
    269 			netDial = d.NetDial
    270 		}
    271 	default:
    272 		return nil, nil, errMalformedURL
    273 	}
    274 
    275 	if netDial == nil {
    276 		netDialer := &net.Dialer{}
    277 		netDial = func(network, addr string) (net.Conn, error) {
    278 			return netDialer.DialContext(ctx, network, addr)
    279 		}
    280 	}
    281 
    282 	// If needed, wrap the dial function to set the connection deadline.
    283 	if deadline, ok := ctx.Deadline(); ok {
    284 		forwardDial := netDial
    285 		netDial = func(network, addr string) (net.Conn, error) {
    286 			c, err := forwardDial(network, addr)
    287 			if err != nil {
    288 				return nil, err
    289 			}
    290 			err = c.SetDeadline(deadline)
    291 			if err != nil {
    292 				c.Close()
    293 				return nil, err
    294 			}
    295 			return c, nil
    296 		}
    297 	}
    298 
    299 	// If needed, wrap the dial function to connect through a proxy.
    300 	if d.Proxy != nil {
    301 		proxyURL, err := d.Proxy(req)
    302 		if err != nil {
    303 			return nil, nil, err
    304 		}
    305 		if proxyURL != nil {
    306 			dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
    307 			if err != nil {
    308 				return nil, nil, err
    309 			}
    310 			netDial = dialer.Dial
    311 		}
    312 	}
    313 
    314 	hostPort, hostNoPort := hostPortNoPort(u)
    315 	trace := httptrace.ContextClientTrace(ctx)
    316 	if trace != nil && trace.GetConn != nil {
    317 		trace.GetConn(hostPort)
    318 	}
    319 
    320 	netConn, err := netDial("tcp", hostPort)
    321 	if trace != nil && trace.GotConn != nil {
    322 		trace.GotConn(httptrace.GotConnInfo{
    323 			Conn: netConn,
    324 		})
    325 	}
    326 	if err != nil {
    327 		return nil, nil, err
    328 	}
    329 
    330 	defer func() {
    331 		if netConn != nil {
    332 			netConn.Close()
    333 		}
    334 	}()
    335 
    336 	if u.Scheme == "https" && d.NetDialTLSContext == nil {
    337 		// If NetDialTLSContext is set, assume that the TLS handshake has already been done
    338 
    339 		cfg := cloneTLSConfig(d.TLSClientConfig)
    340 		if cfg.ServerName == "" {
    341 			cfg.ServerName = hostNoPort
    342 		}
    343 		tlsConn := tls.Client(netConn, cfg)
    344 		netConn = tlsConn
    345 
    346 		if trace != nil && trace.TLSHandshakeStart != nil {
    347 			trace.TLSHandshakeStart()
    348 		}
    349 		err := doHandshake(ctx, tlsConn, cfg)
    350 		if trace != nil && trace.TLSHandshakeDone != nil {
    351 			trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
    352 		}
    353 
    354 		if err != nil {
    355 			return nil, nil, err
    356 		}
    357 	}
    358 
    359 	conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
    360 
    361 	if err := req.Write(netConn); err != nil {
    362 		return nil, nil, err
    363 	}
    364 
    365 	if trace != nil && trace.GotFirstResponseByte != nil {
    366 		if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
    367 			trace.GotFirstResponseByte()
    368 		}
    369 	}
    370 
    371 	resp, err := http.ReadResponse(conn.br, req)
    372 	if err != nil {
    373 		return nil, nil, err
    374 	}
    375 
    376 	if d.Jar != nil {
    377 		if rc := resp.Cookies(); len(rc) > 0 {
    378 			d.Jar.SetCookies(u, rc)
    379 		}
    380 	}
    381 
    382 	if resp.StatusCode != 101 ||
    383 		!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
    384 		!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
    385 		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
    386 		// Before closing the network connection on return from this
    387 		// function, slurp up some of the response to aid application
    388 		// debugging.
    389 		buf := make([]byte, 1024)
    390 		n, _ := io.ReadFull(resp.Body, buf)
    391 		resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
    392 		return nil, resp, ErrBadHandshake
    393 	}
    394 
    395 	for _, ext := range parseExtensions(resp.Header) {
    396 		if ext[""] != "permessage-deflate" {
    397 			continue
    398 		}
    399 		_, snct := ext["server_no_context_takeover"]
    400 		_, cnct := ext["client_no_context_takeover"]
    401 		if !snct || !cnct {
    402 			return nil, resp, errInvalidCompression
    403 		}
    404 		conn.newCompressionWriter = compressNoContextTakeover
    405 		conn.newDecompressionReader = decompressNoContextTakeover
    406 		break
    407 	}
    408 
    409 	resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
    410 	conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
    411 
    412 	netConn.SetDeadline(time.Time{})
    413 	netConn = nil // to avoid close in defer.
    414 	return conn, resp, nil
    415 }
    416 
    417 func cloneTLSConfig(cfg *tls.Config) *tls.Config {
    418 	if cfg == nil {
    419 		return &tls.Config{}
    420 	}
    421 	return cfg.Clone()
    422 }