client.go (12681B)
1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "bytes" 9 "context" 10 "crypto/tls" 11 "errors" 12 "io" 13 "io/ioutil" 14 "net" 15 "net/http" 16 "net/http/httptrace" 17 "net/url" 18 "strings" 19 "time" 20 ) 21 22 // ErrBadHandshake is returned when the server response to opening handshake is 23 // invalid. 24 var ErrBadHandshake = errors.New("websocket: bad handshake") 25 26 var errInvalidCompression = errors.New("websocket: invalid compression negotiation") 27 28 // NewClient creates a new client connection using the given net connection. 29 // The URL u specifies the host and request URI. Use requestHeader to specify 30 // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies 31 // (Cookie). Use the response.Header to get the selected subprotocol 32 // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 33 // 34 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 35 // non-nil *http.Response so that callers can handle redirects, authentication, 36 // etc. 37 // 38 // Deprecated: Use Dialer instead. 39 func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { 40 d := Dialer{ 41 ReadBufferSize: readBufSize, 42 WriteBufferSize: writeBufSize, 43 NetDial: func(net, addr string) (net.Conn, error) { 44 return netConn, nil 45 }, 46 } 47 return d.Dial(u.String(), requestHeader) 48 } 49 50 // A Dialer contains options for connecting to WebSocket server. 51 // 52 // It is safe to call Dialer's methods concurrently. 53 type Dialer struct { 54 // NetDial specifies the dial function for creating TCP connections. If 55 // NetDial is nil, net.Dial is used. 56 NetDial func(network, addr string) (net.Conn, error) 57 58 // NetDialContext specifies the dial function for creating TCP connections. If 59 // NetDialContext is nil, NetDial is used. 60 NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) 61 62 // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If 63 // NetDialTLSContext is nil, NetDialContext is used. 64 // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and 65 // TLSClientConfig is ignored. 66 NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) 67 68 // Proxy specifies a function to return a proxy for a given 69 // Request. If the function returns a non-nil error, the 70 // request is aborted with the provided error. 71 // If Proxy is nil or returns a nil *URL, no proxy is used. 72 Proxy func(*http.Request) (*url.URL, error) 73 74 // TLSClientConfig specifies the TLS configuration to use with tls.Client. 75 // If nil, the default configuration is used. 76 // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake 77 // is done there and TLSClientConfig is ignored. 78 TLSClientConfig *tls.Config 79 80 // HandshakeTimeout specifies the duration for the handshake to complete. 81 HandshakeTimeout time.Duration 82 83 // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 84 // size is zero, then a useful default size is used. The I/O buffer sizes 85 // do not limit the size of the messages that can be sent or received. 86 ReadBufferSize, WriteBufferSize int 87 88 // WriteBufferPool is a pool of buffers for write operations. If the value 89 // is not set, then write buffers are allocated to the connection for the 90 // lifetime of the connection. 91 // 92 // A pool is most useful when the application has a modest volume of writes 93 // across a large number of connections. 94 // 95 // Applications should use a single pool for each unique value of 96 // WriteBufferSize. 97 WriteBufferPool BufferPool 98 99 // Subprotocols specifies the client's requested subprotocols. 100 Subprotocols []string 101 102 // EnableCompression specifies if the client should attempt to negotiate 103 // per message compression (RFC 7692). Setting this value to true does not 104 // guarantee that compression will be supported. Currently only "no context 105 // takeover" modes are supported. 106 EnableCompression bool 107 108 // Jar specifies the cookie jar. 109 // If Jar is nil, cookies are not sent in requests and ignored 110 // in responses. 111 Jar http.CookieJar 112 } 113 114 // Dial creates a new client connection by calling DialContext with a background context. 115 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 116 return d.DialContext(context.Background(), urlStr, requestHeader) 117 } 118 119 var errMalformedURL = errors.New("malformed ws or wss URL") 120 121 func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { 122 hostPort = u.Host 123 hostNoPort = u.Host 124 if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { 125 hostNoPort = hostNoPort[:i] 126 } else { 127 switch u.Scheme { 128 case "wss": 129 hostPort += ":443" 130 case "https": 131 hostPort += ":443" 132 default: 133 hostPort += ":80" 134 } 135 } 136 return hostPort, hostNoPort 137 } 138 139 // DefaultDialer is a dialer with all fields set to the default values. 140 var DefaultDialer = &Dialer{ 141 Proxy: http.ProxyFromEnvironment, 142 HandshakeTimeout: 45 * time.Second, 143 } 144 145 // nilDialer is dialer to use when receiver is nil. 146 var nilDialer = *DefaultDialer 147 148 // DialContext creates a new client connection. Use requestHeader to specify the 149 // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). 150 // Use the response.Header to get the selected subprotocol 151 // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 152 // 153 // The context will be used in the request and in the Dialer. 154 // 155 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 156 // non-nil *http.Response so that callers can handle redirects, authentication, 157 // etcetera. The response body may not contain the entire response and does not 158 // need to be closed by the application. 159 func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 160 if d == nil { 161 d = &nilDialer 162 } 163 164 challengeKey, err := generateChallengeKey() 165 if err != nil { 166 return nil, nil, err 167 } 168 169 u, err := url.Parse(urlStr) 170 if err != nil { 171 return nil, nil, err 172 } 173 174 switch u.Scheme { 175 case "ws": 176 u.Scheme = "http" 177 case "wss": 178 u.Scheme = "https" 179 default: 180 return nil, nil, errMalformedURL 181 } 182 183 if u.User != nil { 184 // User name and password are not allowed in websocket URIs. 185 return nil, nil, errMalformedURL 186 } 187 188 req := &http.Request{ 189 Method: http.MethodGet, 190 URL: u, 191 Proto: "HTTP/1.1", 192 ProtoMajor: 1, 193 ProtoMinor: 1, 194 Header: make(http.Header), 195 Host: u.Host, 196 } 197 req = req.WithContext(ctx) 198 199 // Set the cookies present in the cookie jar of the dialer 200 if d.Jar != nil { 201 for _, cookie := range d.Jar.Cookies(u) { 202 req.AddCookie(cookie) 203 } 204 } 205 206 // Set the request headers using the capitalization for names and values in 207 // RFC examples. Although the capitalization shouldn't matter, there are 208 // servers that depend on it. The Header.Set method is not used because the 209 // method canonicalizes the header names. 210 req.Header["Upgrade"] = []string{"websocket"} 211 req.Header["Connection"] = []string{"Upgrade"} 212 req.Header["Sec-WebSocket-Key"] = []string{challengeKey} 213 req.Header["Sec-WebSocket-Version"] = []string{"13"} 214 if len(d.Subprotocols) > 0 { 215 req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} 216 } 217 for k, vs := range requestHeader { 218 switch { 219 case k == "Host": 220 if len(vs) > 0 { 221 req.Host = vs[0] 222 } 223 case k == "Upgrade" || 224 k == "Connection" || 225 k == "Sec-Websocket-Key" || 226 k == "Sec-Websocket-Version" || 227 k == "Sec-Websocket-Extensions" || 228 (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): 229 return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) 230 case k == "Sec-Websocket-Protocol": 231 req.Header["Sec-WebSocket-Protocol"] = vs 232 default: 233 req.Header[k] = vs 234 } 235 } 236 237 if d.EnableCompression { 238 req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} 239 } 240 241 if d.HandshakeTimeout != 0 { 242 var cancel func() 243 ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) 244 defer cancel() 245 } 246 247 // Get network dial function. 248 var netDial func(network, add string) (net.Conn, error) 249 250 switch u.Scheme { 251 case "http": 252 if d.NetDialContext != nil { 253 netDial = func(network, addr string) (net.Conn, error) { 254 return d.NetDialContext(ctx, network, addr) 255 } 256 } else if d.NetDial != nil { 257 netDial = d.NetDial 258 } 259 case "https": 260 if d.NetDialTLSContext != nil { 261 netDial = func(network, addr string) (net.Conn, error) { 262 return d.NetDialTLSContext(ctx, network, addr) 263 } 264 } else if d.NetDialContext != nil { 265 netDial = func(network, addr string) (net.Conn, error) { 266 return d.NetDialContext(ctx, network, addr) 267 } 268 } else if d.NetDial != nil { 269 netDial = d.NetDial 270 } 271 default: 272 return nil, nil, errMalformedURL 273 } 274 275 if netDial == nil { 276 netDialer := &net.Dialer{} 277 netDial = func(network, addr string) (net.Conn, error) { 278 return netDialer.DialContext(ctx, network, addr) 279 } 280 } 281 282 // If needed, wrap the dial function to set the connection deadline. 283 if deadline, ok := ctx.Deadline(); ok { 284 forwardDial := netDial 285 netDial = func(network, addr string) (net.Conn, error) { 286 c, err := forwardDial(network, addr) 287 if err != nil { 288 return nil, err 289 } 290 err = c.SetDeadline(deadline) 291 if err != nil { 292 c.Close() 293 return nil, err 294 } 295 return c, nil 296 } 297 } 298 299 // If needed, wrap the dial function to connect through a proxy. 300 if d.Proxy != nil { 301 proxyURL, err := d.Proxy(req) 302 if err != nil { 303 return nil, nil, err 304 } 305 if proxyURL != nil { 306 dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) 307 if err != nil { 308 return nil, nil, err 309 } 310 netDial = dialer.Dial 311 } 312 } 313 314 hostPort, hostNoPort := hostPortNoPort(u) 315 trace := httptrace.ContextClientTrace(ctx) 316 if trace != nil && trace.GetConn != nil { 317 trace.GetConn(hostPort) 318 } 319 320 netConn, err := netDial("tcp", hostPort) 321 if trace != nil && trace.GotConn != nil { 322 trace.GotConn(httptrace.GotConnInfo{ 323 Conn: netConn, 324 }) 325 } 326 if err != nil { 327 return nil, nil, err 328 } 329 330 defer func() { 331 if netConn != nil { 332 netConn.Close() 333 } 334 }() 335 336 if u.Scheme == "https" && d.NetDialTLSContext == nil { 337 // If NetDialTLSContext is set, assume that the TLS handshake has already been done 338 339 cfg := cloneTLSConfig(d.TLSClientConfig) 340 if cfg.ServerName == "" { 341 cfg.ServerName = hostNoPort 342 } 343 tlsConn := tls.Client(netConn, cfg) 344 netConn = tlsConn 345 346 if trace != nil && trace.TLSHandshakeStart != nil { 347 trace.TLSHandshakeStart() 348 } 349 err := doHandshake(ctx, tlsConn, cfg) 350 if trace != nil && trace.TLSHandshakeDone != nil { 351 trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) 352 } 353 354 if err != nil { 355 return nil, nil, err 356 } 357 } 358 359 conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) 360 361 if err := req.Write(netConn); err != nil { 362 return nil, nil, err 363 } 364 365 if trace != nil && trace.GotFirstResponseByte != nil { 366 if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { 367 trace.GotFirstResponseByte() 368 } 369 } 370 371 resp, err := http.ReadResponse(conn.br, req) 372 if err != nil { 373 return nil, nil, err 374 } 375 376 if d.Jar != nil { 377 if rc := resp.Cookies(); len(rc) > 0 { 378 d.Jar.SetCookies(u, rc) 379 } 380 } 381 382 if resp.StatusCode != 101 || 383 !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || 384 !tokenListContainsValue(resp.Header, "Connection", "upgrade") || 385 resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { 386 // Before closing the network connection on return from this 387 // function, slurp up some of the response to aid application 388 // debugging. 389 buf := make([]byte, 1024) 390 n, _ := io.ReadFull(resp.Body, buf) 391 resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) 392 return nil, resp, ErrBadHandshake 393 } 394 395 for _, ext := range parseExtensions(resp.Header) { 396 if ext[""] != "permessage-deflate" { 397 continue 398 } 399 _, snct := ext["server_no_context_takeover"] 400 _, cnct := ext["client_no_context_takeover"] 401 if !snct || !cnct { 402 return nil, resp, errInvalidCompression 403 } 404 conn.newCompressionWriter = compressNoContextTakeover 405 conn.newDecompressionReader = decompressNoContextTakeover 406 break 407 } 408 409 resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) 410 conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") 411 412 netConn.SetDeadline(time.Time{}) 413 netConn = nil // to avoid close in defer. 414 return conn, resp, nil 415 } 416 417 func cloneTLSConfig(cfg *tls.Config) *tls.Config { 418 if cfg == nil { 419 return &tls.Config{} 420 } 421 return cfg.Clone() 422 }