gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

client.go (14604B)


      1 package dns
      2 
      3 // A client implementation.
      4 
      5 import (
      6 	"context"
      7 	"crypto/tls"
      8 	"encoding/binary"
      9 	"io"
     10 	"net"
     11 	"strings"
     12 	"time"
     13 )
     14 
     15 const (
     16 	dnsTimeout     time.Duration = 2 * time.Second
     17 	tcpIdleTimeout time.Duration = 8 * time.Second
     18 )
     19 
     20 func isPacketConn(c net.Conn) bool {
     21 	if _, ok := c.(net.PacketConn); !ok {
     22 		return false
     23 	}
     24 
     25 	if ua, ok := c.LocalAddr().(*net.UnixAddr); ok {
     26 		return ua.Net == "unixgram" || ua.Net == "unixpacket"
     27 	}
     28 
     29 	return true
     30 }
     31 
// A Conn represents a connection to a DNS server.
//
// It embeds a net.Conn and adds DNS message framing on top of it: on stream
// connections every message is preceded by a 2-byte big-endian length (see
// Read and Write), while on packet connections each datagram is one message.
type Conn struct {
	net.Conn                         // a net.Conn holding the connection
	UDPSize        uint16            // minimum receive buffer for UDP messages; ReadMsgHeader allocates the larger of UDPSize and MinMsgSize per datagram
	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
	TsigProvider   TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
	tsigRequestMAC string            // MAC returned by the last TSIG-signed WriteMsg; fed back into the next TSIG generation and into TSIG verification in ReadMsg
}
     40 
     41 func (co *Conn) tsigProvider() TsigProvider {
     42 	if co.TsigProvider != nil {
     43 		return co.TsigProvider
     44 	}
     45 	// tsigSecretProvider will return ErrSecret if co.TsigSecret is nil.
     46 	return tsigSecretProvider(co.TsigSecret)
     47 }
     48 
     49 // A Client defines parameters for a DNS client.
     50 type Client struct {
     51 	Net       string      // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
     52 	UDPSize   uint16      // minimum receive buffer for UDP messages
     53 	TLSConfig *tls.Config // TLS connection configuration
     54 	Dialer    *net.Dialer // a net.Dialer used to set local address, timeouts and more
     55 	// Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout,
     56 	// WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and
     57 	// Client.Dialer) or context.Context.Deadline (see ExchangeContext)
     58 	Timeout      time.Duration
     59 	DialTimeout  time.Duration     // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
     60 	ReadTimeout  time.Duration     // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
     61 	WriteTimeout time.Duration     // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
     62 	TsigSecret   map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
     63 	TsigProvider TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
     64 
     65 	// SingleInflight previously serialised multiple concurrent queries for the
     66 	// same Qname, Qtype and Qclass to ensure only one would be in flight at a
     67 	// time.
     68 	//
     69 	// Deprecated: This is a no-op. Callers should implement their own in flight
     70 	// query caching if needed. See github.com/miekg/dns/issues/1449.
     71 	SingleInflight bool
     72 }
     73 
     74 // Exchange performs a synchronous UDP query. It sends the message m to the address
     75 // contained in a and waits for a reply. Exchange does not retry a failed query, nor
     76 // will it fall back to TCP in case of truncation.
     77 // See client.Exchange for more information on setting larger buffer sizes.
     78 func Exchange(m *Msg, a string) (r *Msg, err error) {
     79 	client := Client{Net: "udp"}
     80 	r, _, err = client.Exchange(m, a)
     81 	return r, err
     82 }
     83 
     84 func (c *Client) dialTimeout() time.Duration {
     85 	if c.Timeout != 0 {
     86 		return c.Timeout
     87 	}
     88 	if c.DialTimeout != 0 {
     89 		return c.DialTimeout
     90 	}
     91 	return dnsTimeout
     92 }
     93 
     94 func (c *Client) readTimeout() time.Duration {
     95 	if c.ReadTimeout != 0 {
     96 		return c.ReadTimeout
     97 	}
     98 	return dnsTimeout
     99 }
    100 
    101 func (c *Client) writeTimeout() time.Duration {
    102 	if c.WriteTimeout != 0 {
    103 		return c.WriteTimeout
    104 	}
    105 	return dnsTimeout
    106 }
    107 
    108 // Dial connects to the address on the named network.
    109 func (c *Client) Dial(address string) (conn *Conn, err error) {
    110 	return c.DialContext(context.Background(), address)
    111 }
    112 
    113 // DialContext connects to the address on the named network, with a context.Context.
    114 func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) {
    115 	// create a new dialer with the appropriate timeout
    116 	var d net.Dialer
    117 	if c.Dialer == nil {
    118 		d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())}
    119 	} else {
    120 		d = *c.Dialer
    121 	}
    122 
    123 	network := c.Net
    124 	if network == "" {
    125 		network = "udp"
    126 	}
    127 
    128 	useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls")
    129 
    130 	conn = new(Conn)
    131 	if useTLS {
    132 		network = strings.TrimSuffix(network, "-tls")
    133 
    134 		tlsDialer := tls.Dialer{
    135 			NetDialer: &d,
    136 			Config:    c.TLSConfig,
    137 		}
    138 		conn.Conn, err = tlsDialer.DialContext(ctx, network, address)
    139 	} else {
    140 		conn.Conn, err = d.DialContext(ctx, network, address)
    141 	}
    142 	if err != nil {
    143 		return nil, err
    144 	}
    145 	conn.UDPSize = c.UDPSize
    146 	return conn, nil
    147 }
    148 
    149 // Exchange performs a synchronous query. It sends the message m to the address
    150 // contained in a and waits for a reply. Basic use pattern with a *dns.Client:
    151 //
    152 //	c := new(dns.Client)
    153 //	in, rtt, err := c.Exchange(message, "127.0.0.1:53")
    154 //
    155 // Exchange does not retry a failed query, nor will it fall back to TCP in
    156 // case of truncation.
    157 // It is up to the caller to create a message that allows for larger responses to be
    158 // returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger
    159 // buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
    160 // of 512 bytes
    161 // To specify a local address or a timeout, the caller has to set the `Client.Dialer`
    162 // attribute appropriately
    163 func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
    164 	co, err := c.Dial(address)
    165 
    166 	if err != nil {
    167 		return nil, 0, err
    168 	}
    169 	defer co.Close()
    170 	return c.ExchangeWithConn(m, co)
    171 }
    172 
    173 // ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection
    174 // that will be used instead of creating a new one.
    175 // Usage pattern with a *dns.Client:
    176 //
    177 //	c := new(dns.Client)
    178 //	// connection management logic goes here
    179 //
    180 //	conn := c.Dial(address)
    181 //	in, rtt, err := c.ExchangeWithConn(message, conn)
    182 //
    183 // This allows users of the library to implement their own connection management,
    184 // as opposed to Exchange, which will always use new connections and incur the added overhead
    185 // that entails when using "tcp" and especially "tcp-tls" clients.
    186 //
    187 // When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to
    188 // prevent one cancellation from canceling all outstanding requests.
    189 func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
    190 	return c.exchangeWithConnContext(context.Background(), m, conn)
    191 }
    192 
    193 func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
    194 	opt := m.IsEdns0()
    195 	// If EDNS0 is used use that for size.
    196 	if opt != nil && opt.UDPSize() >= MinMsgSize {
    197 		co.UDPSize = opt.UDPSize()
    198 	}
    199 	// Otherwise use the client's configured UDP size.
    200 	if opt == nil && c.UDPSize >= MinMsgSize {
    201 		co.UDPSize = c.UDPSize
    202 	}
    203 
    204 	// write with the appropriate write timeout
    205 	t := time.Now()
    206 	writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
    207 	readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
    208 	if deadline, ok := ctx.Deadline(); ok {
    209 		if deadline.Before(writeDeadline) {
    210 			writeDeadline = deadline
    211 		}
    212 		if deadline.Before(readDeadline) {
    213 			readDeadline = deadline
    214 		}
    215 	}
    216 	co.SetWriteDeadline(writeDeadline)
    217 	co.SetReadDeadline(readDeadline)
    218 
    219 	co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
    220 
    221 	if err = co.WriteMsg(m); err != nil {
    222 		return nil, 0, err
    223 	}
    224 
    225 	if isPacketConn(co.Conn) {
    226 		for {
    227 			r, err = co.ReadMsg()
    228 			// Ignore replies with mismatched IDs because they might be
    229 			// responses to earlier queries that timed out.
    230 			if err != nil || r.Id == m.Id {
    231 				break
    232 			}
    233 		}
    234 	} else {
    235 		r, err = co.ReadMsg()
    236 		if err == nil && r.Id != m.Id {
    237 			err = ErrId
    238 		}
    239 	}
    240 	rtt = time.Since(t)
    241 	return r, rtt, err
    242 }
    243 
    244 // ReadMsg reads a message from the connection co.
    245 // If the received message contains a TSIG record the transaction signature
    246 // is verified. This method always tries to return the message, however if an
    247 // error is returned there are no guarantees that the returned message is a
    248 // valid representation of the packet read.
    249 func (co *Conn) ReadMsg() (*Msg, error) {
    250 	p, err := co.ReadMsgHeader(nil)
    251 	if err != nil {
    252 		return nil, err
    253 	}
    254 
    255 	m := new(Msg)
    256 	if err := m.Unpack(p); err != nil {
    257 		// If an error was returned, we still want to allow the user to use
    258 		// the message, but naively they can just check err if they don't want
    259 		// to use an erroneous message
    260 		return m, err
    261 	}
    262 	if t := m.IsTsig(); t != nil {
    263 		// Need to work on the original message p, as that was used to calculate the tsig.
    264 		err = TsigVerifyWithProvider(p, co.tsigProvider(), co.tsigRequestMAC, false)
    265 	}
    266 	return m, err
    267 }
    268 
    269 // ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil).
    270 // Returns message as a byte slice to be parsed with Msg.Unpack later on.
    271 // Note that error handling on the message body is not possible as only the header is parsed.
    272 func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
    273 	var (
    274 		p   []byte
    275 		n   int
    276 		err error
    277 	)
    278 
    279 	if isPacketConn(co.Conn) {
    280 		if co.UDPSize > MinMsgSize {
    281 			p = make([]byte, co.UDPSize)
    282 		} else {
    283 			p = make([]byte, MinMsgSize)
    284 		}
    285 		n, err = co.Read(p)
    286 	} else {
    287 		var length uint16
    288 		if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
    289 			return nil, err
    290 		}
    291 
    292 		p = make([]byte, length)
    293 		n, err = io.ReadFull(co.Conn, p)
    294 	}
    295 
    296 	if err != nil {
    297 		return nil, err
    298 	} else if n < headerSize {
    299 		return nil, ErrShortRead
    300 	}
    301 
    302 	p = p[:n]
    303 	if hdr != nil {
    304 		dh, _, err := unpackMsgHdr(p, 0)
    305 		if err != nil {
    306 			return nil, err
    307 		}
    308 		*hdr = dh
    309 	}
    310 	return p, err
    311 }
    312 
    313 // Read implements the net.Conn read method.
    314 func (co *Conn) Read(p []byte) (n int, err error) {
    315 	if co.Conn == nil {
    316 		return 0, ErrConnEmpty
    317 	}
    318 
    319 	if isPacketConn(co.Conn) {
    320 		// UDP connection
    321 		return co.Conn.Read(p)
    322 	}
    323 
    324 	var length uint16
    325 	if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
    326 		return 0, err
    327 	}
    328 	if int(length) > len(p) {
    329 		return 0, io.ErrShortBuffer
    330 	}
    331 
    332 	return io.ReadFull(co.Conn, p[:length])
    333 }
    334 
    335 // WriteMsg sends a message through the connection co.
    336 // If the message m contains a TSIG record the transaction
    337 // signature is calculated.
    338 func (co *Conn) WriteMsg(m *Msg) (err error) {
    339 	var out []byte
    340 	if t := m.IsTsig(); t != nil {
    341 		// Set tsigRequestMAC for the next read, although only used in zone transfers.
    342 		out, co.tsigRequestMAC, err = TsigGenerateWithProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
    343 	} else {
    344 		out, err = m.Pack()
    345 	}
    346 	if err != nil {
    347 		return err
    348 	}
    349 	_, err = co.Write(out)
    350 	return err
    351 }
    352 
    353 // Write implements the net.Conn Write method.
    354 func (co *Conn) Write(p []byte) (int, error) {
    355 	if len(p) > MaxMsgSize {
    356 		return 0, &Error{err: "message too large"}
    357 	}
    358 
    359 	if isPacketConn(co.Conn) {
    360 		return co.Conn.Write(p)
    361 	}
    362 
    363 	msg := make([]byte, 2+len(p))
    364 	binary.BigEndian.PutUint16(msg, uint16(len(p)))
    365 	copy(msg[2:], p)
    366 	return co.Conn.Write(msg)
    367 }
    368 
    369 // Return the appropriate timeout for a specific request
    370 func (c *Client) getTimeoutForRequest(timeout time.Duration) time.Duration {
    371 	var requestTimeout time.Duration
    372 	if c.Timeout != 0 {
    373 		requestTimeout = c.Timeout
    374 	} else {
    375 		requestTimeout = timeout
    376 	}
    377 	// net.Dialer.Timeout has priority if smaller than the timeouts computed so
    378 	// far
    379 	if c.Dialer != nil && c.Dialer.Timeout != 0 {
    380 		if c.Dialer.Timeout < requestTimeout {
    381 			requestTimeout = c.Dialer.Timeout
    382 		}
    383 	}
    384 	return requestTimeout
    385 }
    386 
    387 // Dial connects to the address on the named network.
    388 func Dial(network, address string) (conn *Conn, err error) {
    389 	conn = new(Conn)
    390 	conn.Conn, err = net.Dial(network, address)
    391 	if err != nil {
    392 		return nil, err
    393 	}
    394 	return conn, nil
    395 }
    396 
    397 // ExchangeContext performs a synchronous UDP query, like Exchange. It
    398 // additionally obeys deadlines from the passed Context.
    399 func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
    400 	client := Client{Net: "udp"}
    401 	r, _, err = client.ExchangeContext(ctx, m, a)
    402 	// ignoring rtt to leave the original ExchangeContext API unchanged, but
    403 	// this function will go away
    404 	return r, err
    405 }
    406 
    407 // ExchangeConn performs a synchronous query. It sends the message m via the connection
    408 // c and waits for a reply. The connection c is not closed by ExchangeConn.
    409 // Deprecated: This function is going away, but can easily be mimicked:
    410 //
    411 //	co := &dns.Conn{Conn: c} // c is your net.Conn
    412 //	co.WriteMsg(m)
    413 //	in, _  := co.ReadMsg()
    414 //	co.Close()
    415 func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
    416 	println("dns: ExchangeConn: this function is deprecated")
    417 	co := new(Conn)
    418 	co.Conn = c
    419 	if err = co.WriteMsg(m); err != nil {
    420 		return nil, err
    421 	}
    422 	r, err = co.ReadMsg()
    423 	if err == nil && r.Id != m.Id {
    424 		err = ErrId
    425 	}
    426 	return r, err
    427 }
    428 
    429 // DialTimeout acts like Dial but takes a timeout.
    430 func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
    431 	client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}}
    432 	return client.Dial(address)
    433 }
    434 
    435 // DialWithTLS connects to the address on the named network with TLS.
    436 func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) {
    437 	if !strings.HasSuffix(network, "-tls") {
    438 		network += "-tls"
    439 	}
    440 	client := Client{Net: network, TLSConfig: tlsConfig}
    441 	return client.Dial(address)
    442 }
    443 
    444 // DialTimeoutWithTLS acts like DialWithTLS but takes a timeout.
    445 func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) {
    446 	if !strings.HasSuffix(network, "-tls") {
    447 		network += "-tls"
    448 	}
    449 	client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig}
    450 	return client.Dial(address)
    451 }
    452 
    453 // ExchangeContext acts like Exchange, but honors the deadline on the provided
    454 // context, if present. If there is both a context deadline and a configured
    455 // timeout on the client, the earliest of the two takes effect.
    456 func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
    457 	conn, err := c.DialContext(ctx, a)
    458 	if err != nil {
    459 		return nil, 0, err
    460 	}
    461 	defer conn.Close()
    462 
    463 	return c.exchangeWithConnContext(ctx, m, conn)
    464 }