client.go (14604B)
1 package dns 2 3 // A client implementation. 4 5 import ( 6 "context" 7 "crypto/tls" 8 "encoding/binary" 9 "io" 10 "net" 11 "strings" 12 "time" 13 ) 14 15 const ( 16 dnsTimeout time.Duration = 2 * time.Second 17 tcpIdleTimeout time.Duration = 8 * time.Second 18 ) 19 20 func isPacketConn(c net.Conn) bool { 21 if _, ok := c.(net.PacketConn); !ok { 22 return false 23 } 24 25 if ua, ok := c.LocalAddr().(*net.UnixAddr); ok { 26 return ua.Net == "unixgram" || ua.Net == "unixpacket" 27 } 28 29 return true 30 } 31 32 // A Conn represents a connection to a DNS server. 33 type Conn struct { 34 net.Conn // a net.Conn holding the connection 35 UDPSize uint16 // minimum receive buffer for UDP messages 36 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) 37 TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. 38 tsigRequestMAC string 39 } 40 41 func (co *Conn) tsigProvider() TsigProvider { 42 if co.TsigProvider != nil { 43 return co.TsigProvider 44 } 45 // tsigSecretProvider will return ErrSecret if co.TsigSecret is nil. 46 return tsigSecretProvider(co.TsigSecret) 47 } 48 49 // A Client defines parameters for a DNS client. 50 type Client struct { 51 Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) 52 UDPSize uint16 // minimum receive buffer for UDP messages 53 TLSConfig *tls.Config // TLS connection configuration 54 Dialer *net.Dialer // a net.Dialer used to set local address, timeouts and more 55 // Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout, 56 // WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and 57 // Client.Dialer) or context.Context.Deadline (see ExchangeContext) 58 Timeout time.Duration 59 DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero 60 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 61 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 62 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) 63 TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. 64 65 // SingleInflight previously serialised multiple concurrent queries for the 66 // same Qname, Qtype and Qclass to ensure only one would be in flight at a 67 // time. 68 // 69 // Deprecated: This is a no-op. Callers should implement their own in flight 70 // query caching if needed. See github.com/miekg/dns/issues/1449. 71 SingleInflight bool 72 } 73 74 // Exchange performs a synchronous UDP query. It sends the message m to the address 75 // contained in a and waits for a reply. Exchange does not retry a failed query, nor 76 // will it fall back to TCP in case of truncation. 77 // See client.Exchange for more information on setting larger buffer sizes. 78 func Exchange(m *Msg, a string) (r *Msg, err error) { 79 client := Client{Net: "udp"} 80 r, _, err = client.Exchange(m, a) 81 return r, err 82 } 83 84 func (c *Client) dialTimeout() time.Duration { 85 if c.Timeout != 0 { 86 return c.Timeout 87 } 88 if c.DialTimeout != 0 { 89 return c.DialTimeout 90 } 91 return dnsTimeout 92 } 93 94 func (c *Client) readTimeout() time.Duration { 95 if c.ReadTimeout != 0 { 96 return c.ReadTimeout 97 } 98 return dnsTimeout 99 } 100 101 func (c *Client) writeTimeout() time.Duration { 102 if c.WriteTimeout != 0 { 103 return c.WriteTimeout 104 } 105 return dnsTimeout 106 } 107 108 // Dial connects to the address on the named network. 109 func (c *Client) Dial(address string) (conn *Conn, err error) { 110 return c.DialContext(context.Background(), address) 111 } 112 113 // DialContext connects to the address on the named network, with a context.Context. 114 func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) { 115 // create a new dialer with the appropriate timeout 116 var d net.Dialer 117 if c.Dialer == nil { 118 d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())} 119 } else { 120 d = *c.Dialer 121 } 122 123 network := c.Net 124 if network == "" { 125 network = "udp" 126 } 127 128 useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls") 129 130 conn = new(Conn) 131 if useTLS { 132 network = strings.TrimSuffix(network, "-tls") 133 134 tlsDialer := tls.Dialer{ 135 NetDialer: &d, 136 Config: c.TLSConfig, 137 } 138 conn.Conn, err = tlsDialer.DialContext(ctx, network, address) 139 } else { 140 conn.Conn, err = d.DialContext(ctx, network, address) 141 } 142 if err != nil { 143 return nil, err 144 } 145 conn.UDPSize = c.UDPSize 146 return conn, nil 147 } 148 149 // Exchange performs a synchronous query. It sends the message m to the address 150 // contained in a and waits for a reply. Basic use pattern with a *dns.Client: 151 // 152 // c := new(dns.Client) 153 // in, rtt, err := c.Exchange(message, "127.0.0.1:53") 154 // 155 // Exchange does not retry a failed query, nor will it fall back to TCP in 156 // case of truncation. 157 // It is up to the caller to create a message that allows for larger responses to be 158 // returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger 159 // buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit 160 // of 512 bytes 161 // To specify a local address or a timeout, the caller has to set the `Client.Dialer` 162 // attribute appropriately 163 func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) { 164 co, err := c.Dial(address) 165 166 if err != nil { 167 return nil, 0, err 168 } 169 defer co.Close() 170 return c.ExchangeWithConn(m, co) 171 } 172 173 // ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection 174 // that will be used instead of creating a new one. 175 // Usage pattern with a *dns.Client: 176 // 177 // c := new(dns.Client) 178 // // connection management logic goes here 179 // 180 // conn := c.Dial(address) 181 // in, rtt, err := c.ExchangeWithConn(message, conn) 182 // 183 // This allows users of the library to implement their own connection management, 184 // as opposed to Exchange, which will always use new connections and incur the added overhead 185 // that entails when using "tcp" and especially "tcp-tls" clients. 186 // 187 // When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to 188 // prevent one cancellation from canceling all outstanding requests. 189 func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { 190 return c.exchangeWithConnContext(context.Background(), m, conn) 191 } 192 193 func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { 194 opt := m.IsEdns0() 195 // If EDNS0 is used use that for size. 196 if opt != nil && opt.UDPSize() >= MinMsgSize { 197 co.UDPSize = opt.UDPSize() 198 } 199 // Otherwise use the client's configured UDP size. 200 if opt == nil && c.UDPSize >= MinMsgSize { 201 co.UDPSize = c.UDPSize 202 } 203 204 // write with the appropriate write timeout 205 t := time.Now() 206 writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout())) 207 readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout())) 208 if deadline, ok := ctx.Deadline(); ok { 209 if deadline.Before(writeDeadline) { 210 writeDeadline = deadline 211 } 212 if deadline.Before(readDeadline) { 213 readDeadline = deadline 214 } 215 } 216 co.SetWriteDeadline(writeDeadline) 217 co.SetReadDeadline(readDeadline) 218 219 co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider 220 221 if err = co.WriteMsg(m); err != nil { 222 return nil, 0, err 223 } 224 225 if isPacketConn(co.Conn) { 226 for { 227 r, err = co.ReadMsg() 228 // Ignore replies with mismatched IDs because they might be 229 // responses to earlier queries that timed out. 230 if err != nil || r.Id == m.Id { 231 break 232 } 233 } 234 } else { 235 r, err = co.ReadMsg() 236 if err == nil && r.Id != m.Id { 237 err = ErrId 238 } 239 } 240 rtt = time.Since(t) 241 return r, rtt, err 242 } 243 244 // ReadMsg reads a message from the connection co. 245 // If the received message contains a TSIG record the transaction signature 246 // is verified. This method always tries to return the message, however if an 247 // error is returned there are no guarantees that the returned message is a 248 // valid representation of the packet read. 249 func (co *Conn) ReadMsg() (*Msg, error) { 250 p, err := co.ReadMsgHeader(nil) 251 if err != nil { 252 return nil, err 253 } 254 255 m := new(Msg) 256 if err := m.Unpack(p); err != nil { 257 // If an error was returned, we still want to allow the user to use 258 // the message, but naively they can just check err if they don't want 259 // to use an erroneous message 260 return m, err 261 } 262 if t := m.IsTsig(); t != nil { 263 // Need to work on the original message p, as that was used to calculate the tsig. 264 err = TsigVerifyWithProvider(p, co.tsigProvider(), co.tsigRequestMAC, false) 265 } 266 return m, err 267 } 268 269 // ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil). 270 // Returns message as a byte slice to be parsed with Msg.Unpack later on. 271 // Note that error handling on the message body is not possible as only the header is parsed. 272 func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { 273 var ( 274 p []byte 275 n int 276 err error 277 ) 278 279 if isPacketConn(co.Conn) { 280 if co.UDPSize > MinMsgSize { 281 p = make([]byte, co.UDPSize) 282 } else { 283 p = make([]byte, MinMsgSize) 284 } 285 n, err = co.Read(p) 286 } else { 287 var length uint16 288 if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { 289 return nil, err 290 } 291 292 p = make([]byte, length) 293 n, err = io.ReadFull(co.Conn, p) 294 } 295 296 if err != nil { 297 return nil, err 298 } else if n < headerSize { 299 return nil, ErrShortRead 300 } 301 302 p = p[:n] 303 if hdr != nil { 304 dh, _, err := unpackMsgHdr(p, 0) 305 if err != nil { 306 return nil, err 307 } 308 *hdr = dh 309 } 310 return p, err 311 } 312 313 // Read implements the net.Conn read method. 314 func (co *Conn) Read(p []byte) (n int, err error) { 315 if co.Conn == nil { 316 return 0, ErrConnEmpty 317 } 318 319 if isPacketConn(co.Conn) { 320 // UDP connection 321 return co.Conn.Read(p) 322 } 323 324 var length uint16 325 if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { 326 return 0, err 327 } 328 if int(length) > len(p) { 329 return 0, io.ErrShortBuffer 330 } 331 332 return io.ReadFull(co.Conn, p[:length]) 333 } 334 335 // WriteMsg sends a message through the connection co. 336 // If the message m contains a TSIG record the transaction 337 // signature is calculated. 338 func (co *Conn) WriteMsg(m *Msg) (err error) { 339 var out []byte 340 if t := m.IsTsig(); t != nil { 341 // Set tsigRequestMAC for the next read, although only used in zone transfers. 342 out, co.tsigRequestMAC, err = TsigGenerateWithProvider(m, co.tsigProvider(), co.tsigRequestMAC, false) 343 } else { 344 out, err = m.Pack() 345 } 346 if err != nil { 347 return err 348 } 349 _, err = co.Write(out) 350 return err 351 } 352 353 // Write implements the net.Conn Write method. 354 func (co *Conn) Write(p []byte) (int, error) { 355 if len(p) > MaxMsgSize { 356 return 0, &Error{err: "message too large"} 357 } 358 359 if isPacketConn(co.Conn) { 360 return co.Conn.Write(p) 361 } 362 363 msg := make([]byte, 2+len(p)) 364 binary.BigEndian.PutUint16(msg, uint16(len(p))) 365 copy(msg[2:], p) 366 return co.Conn.Write(msg) 367 } 368 369 // Return the appropriate timeout for a specific request 370 func (c *Client) getTimeoutForRequest(timeout time.Duration) time.Duration { 371 var requestTimeout time.Duration 372 if c.Timeout != 0 { 373 requestTimeout = c.Timeout 374 } else { 375 requestTimeout = timeout 376 } 377 // net.Dialer.Timeout has priority if smaller than the timeouts computed so 378 // far 379 if c.Dialer != nil && c.Dialer.Timeout != 0 { 380 if c.Dialer.Timeout < requestTimeout { 381 requestTimeout = c.Dialer.Timeout 382 } 383 } 384 return requestTimeout 385 } 386 387 // Dial connects to the address on the named network. 388 func Dial(network, address string) (conn *Conn, err error) { 389 conn = new(Conn) 390 conn.Conn, err = net.Dial(network, address) 391 if err != nil { 392 return nil, err 393 } 394 return conn, nil 395 } 396 397 // ExchangeContext performs a synchronous UDP query, like Exchange. It 398 // additionally obeys deadlines from the passed Context. 399 func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) { 400 client := Client{Net: "udp"} 401 r, _, err = client.ExchangeContext(ctx, m, a) 402 // ignoring rtt to leave the original ExchangeContext API unchanged, but 403 // this function will go away 404 return r, err 405 } 406 407 // ExchangeConn performs a synchronous query. It sends the message m via the connection 408 // c and waits for a reply. The connection c is not closed by ExchangeConn. 409 // Deprecated: This function is going away, but can easily be mimicked: 410 // 411 // co := &dns.Conn{Conn: c} // c is your net.Conn 412 // co.WriteMsg(m) 413 // in, _ := co.ReadMsg() 414 // co.Close() 415 func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { 416 println("dns: ExchangeConn: this function is deprecated") 417 co := new(Conn) 418 co.Conn = c 419 if err = co.WriteMsg(m); err != nil { 420 return nil, err 421 } 422 r, err = co.ReadMsg() 423 if err == nil && r.Id != m.Id { 424 err = ErrId 425 } 426 return r, err 427 } 428 429 // DialTimeout acts like Dial but takes a timeout. 430 func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { 431 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}} 432 return client.Dial(address) 433 } 434 435 // DialWithTLS connects to the address on the named network with TLS. 436 func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) { 437 if !strings.HasSuffix(network, "-tls") { 438 network += "-tls" 439 } 440 client := Client{Net: network, TLSConfig: tlsConfig} 441 return client.Dial(address) 442 } 443 444 // DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. 445 func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) { 446 if !strings.HasSuffix(network, "-tls") { 447 network += "-tls" 448 } 449 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig} 450 return client.Dial(address) 451 } 452 453 // ExchangeContext acts like Exchange, but honors the deadline on the provided 454 // context, if present. If there is both a context deadline and a configured 455 // timeout on the client, the earliest of the two takes effect. 456 func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 457 conn, err := c.DialContext(ctx, a) 458 if err != nil { 459 return nil, 0, err 460 } 461 defer conn.Close() 462 463 return c.exchangeWithConnContext(ctx, m, conn) 464 }