// msg.go (32638B)
1 // DNS packet assembly, see RFC 1035. Converting from - Unpack() - 2 // and to - Pack() - wire format. 3 // All the packers and unpackers take a (msg []byte, off int) 4 // and return (off1 int, ok bool). If they return ok==false, they 5 // also return off1==len(msg), so that the next unpacker will 6 // also fail. This lets us avoid checks of ok until the end of a 7 // packing sequence. 8 9 package dns 10 11 //go:generate go run msg_generate.go 12 13 import ( 14 "crypto/rand" 15 "encoding/binary" 16 "fmt" 17 "math/big" 18 "strconv" 19 "strings" 20 ) 21 22 const ( 23 maxCompressionOffset = 2 << 13 // We have 14 bits for the compression pointer 24 maxDomainNameWireOctets = 255 // See RFC 1035 section 2.3.4 25 26 // This is the maximum number of compression pointers that should occur in a 27 // semantically valid message. Each label in a domain name must be at least one 28 // octet and is separated by a period. The root label won't be represented by a 29 // compression pointer to a compression pointer, hence the -2 to exclude the 30 // smallest valid root label. 31 // 32 // It is possible to construct a valid message that has more compression pointers 33 // than this, and still doesn't loop, by pointing to a previous pointer. This is 34 // not something a well written implementation should ever do, so we leave them 35 // to trip the maximum compression pointer check. 36 maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2 37 38 // This is the maximum length of a domain name in presentation format. The 39 // maximum wire length of a domain name is 255 octets (see above), with the 40 // maximum label length being 63. The wire format requires one extra byte over 41 // the presentation format, reducing the number of octets by 1. Each label in 42 // the name will be separated by a single period, with each octet in the label 43 // expanding to at most 4 bytes (\DDD). 
If all other labels are of the maximum 44 // length, then the final label can only be 61 octets long to not exceed the 45 // maximum allowed wire length. 46 maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1 47 ) 48 49 // Errors defined in this package. 50 var ( 51 ErrAlg error = &Error{err: "bad algorithm"} // ErrAlg indicates an error with the (DNSSEC) algorithm. 52 ErrAuth error = &Error{err: "bad authentication"} // ErrAuth indicates an error in the TSIG authentication. 53 ErrBuf error = &Error{err: "buffer size too small"} // ErrBuf indicates that the buffer used is too small for the message. 54 ErrConnEmpty error = &Error{err: "conn has no connection"} // ErrConnEmpty indicates a connection is being used before it is initialized. 55 ErrExtendedRcode error = &Error{err: "bad extended rcode"} // ErrExtendedRcode ... 56 ErrFqdn error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot. 57 ErrId error = &Error{err: "id mismatch"} // ErrId indicates there is a mismatch with the message's ID. 58 ErrKeyAlg error = &Error{err: "bad key algorithm"} // ErrKeyAlg indicates that the algorithm in the key is not valid. 59 ErrKey error = &Error{err: "bad key"} 60 ErrKeySize error = &Error{err: "bad key size"} 61 ErrLongDomain error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)} 62 ErrNoSig error = &Error{err: "no signature found"} 63 ErrPrivKey error = &Error{err: "bad private key"} 64 ErrRcode error = &Error{err: "bad rcode"} 65 ErrRdata error = &Error{err: "bad rdata"} 66 ErrRRset error = &Error{err: "bad rrset"} 67 ErrSecret error = &Error{err: "no secrets defined"} 68 ErrShortRead error = &Error{err: "short read"} 69 ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated. 
70 ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers. 71 ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication. 72 ) 73 74 // Id by default returns a 16-bit random number to be used as a message id. The 75 // number is drawn from a cryptographically secure random number generator. 76 // This being a variable the function can be reassigned to a custom function. 77 // For instance, to make it return a static value for testing: 78 // 79 // dns.Id = func() uint16 { return 3 } 80 var Id = id 81 82 // id returns a 16 bits random number to be used as a 83 // message id. The random provided should be good enough. 84 func id() uint16 { 85 var output uint16 86 err := binary.Read(rand.Reader, binary.BigEndian, &output) 87 if err != nil { 88 panic("dns: reading random id failed: " + err.Error()) 89 } 90 return output 91 } 92 93 // MsgHdr is a a manually-unpacked version of (id, bits). 94 type MsgHdr struct { 95 Id uint16 96 Response bool 97 Opcode int 98 Authoritative bool 99 Truncated bool 100 RecursionDesired bool 101 RecursionAvailable bool 102 Zero bool 103 AuthenticatedData bool 104 CheckingDisabled bool 105 Rcode int 106 } 107 108 // Msg contains the layout of a DNS message. 109 type Msg struct { 110 MsgHdr 111 Compress bool `json:"-"` // If true, the message will be compressed when converted to wire format. 112 Question []Question // Holds the RR(s) of the question section. 113 Answer []RR // Holds the RR(s) of the answer section. 114 Ns []RR // Holds the RR(s) of the authority section. 115 Extra []RR // Holds the RR(s) of the additional section. 116 } 117 118 // ClassToString is a maps Classes to strings for each CLASS wire type. 119 var ClassToString = map[uint16]string{ 120 ClassINET: "IN", 121 ClassCSNET: "CS", 122 ClassCHAOS: "CH", 123 ClassHESIOD: "HS", 124 ClassNONE: "NONE", 125 ClassANY: "ANY", 126 } 127 128 // OpcodeToString maps Opcodes to strings. 
129 var OpcodeToString = map[int]string{ 130 OpcodeQuery: "QUERY", 131 OpcodeIQuery: "IQUERY", 132 OpcodeStatus: "STATUS", 133 OpcodeNotify: "NOTIFY", 134 OpcodeUpdate: "UPDATE", 135 } 136 137 // RcodeToString maps Rcodes to strings. 138 var RcodeToString = map[int]string{ 139 RcodeSuccess: "NOERROR", 140 RcodeFormatError: "FORMERR", 141 RcodeServerFailure: "SERVFAIL", 142 RcodeNameError: "NXDOMAIN", 143 RcodeNotImplemented: "NOTIMP", 144 RcodeRefused: "REFUSED", 145 RcodeYXDomain: "YXDOMAIN", // See RFC 2136 146 RcodeYXRrset: "YXRRSET", 147 RcodeNXRrset: "NXRRSET", 148 RcodeNotAuth: "NOTAUTH", 149 RcodeNotZone: "NOTZONE", 150 RcodeBadSig: "BADSIG", // Also known as RcodeBadVers, see RFC 6891 151 // RcodeBadVers: "BADVERS", 152 RcodeBadKey: "BADKEY", 153 RcodeBadTime: "BADTIME", 154 RcodeBadMode: "BADMODE", 155 RcodeBadName: "BADNAME", 156 RcodeBadAlg: "BADALG", 157 RcodeBadTrunc: "BADTRUNC", 158 RcodeBadCookie: "BADCOOKIE", 159 } 160 161 // compressionMap is used to allow a more efficient compression map 162 // to be used for internal packDomainName calls without changing the 163 // signature or functionality of public API. 164 // 165 // In particular, map[string]uint16 uses 25% less per-entry memory 166 // than does map[string]int. 167 type compressionMap struct { 168 ext map[string]int // external callers 169 int map[string]uint16 // internal callers 170 } 171 172 func (m compressionMap) valid() bool { 173 return m.int != nil || m.ext != nil 174 } 175 176 func (m compressionMap) insert(s string, pos int) { 177 if m.ext != nil { 178 m.ext[s] = pos 179 } else { 180 m.int[s] = uint16(pos) 181 } 182 } 183 184 func (m compressionMap) find(s string) (int, bool) { 185 if m.ext != nil { 186 pos, ok := m.ext[s] 187 return pos, ok 188 } 189 190 pos, ok := m.int[s] 191 return int(pos), ok 192 } 193 194 // Domain names are a sequence of counted strings 195 // split at the dots. They end with a zero-length string. 
// PackDomainName packs a domain name s into msg[off:].
// If compression is wanted compress must be true and the compression
// map needs to hold a mapping between domain names and offsets
// pointing into msg.
//
// It returns the offset just past the packed name. An empty s packs
// nothing and returns off unchanged; a name that is not fully qualified
// is rejected with ErrFqdn.
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
	return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
}

// packDomainName is the implementation behind PackDomainName. When
// compression is valid, every suffix that starts below maxCompressionOffset
// and is not yet in the map is recorded there (even if compress is false),
// so that later names can point at it.
func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
	// XXX: A logical copy of this function exists in IsDomainName and
	// should be kept in sync with this function.

	ls := len(s)
	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
		return off, nil
	}

	// If not fully qualified, error out.
	if !IsFqdn(s) {
		return len(msg), ErrFqdn
	}

	// Each dot ends a segment of the name.
	// We trade each dot byte for a length byte.
	// Except for escaped dots (\.), which are normal dots.
	// There is also a trailing zero.

	// Compression: offset of an earlier occurrence of the remaining suffix,
	// or -1 if none was found (or compress is false).
	pointer := -1

	// Emit sequence of counted strings, chopping at dots.
	var (
		// begin is the index (into s, or into bs once it exists) where the
		// current label starts.
		begin int
		// compBegin is begin translated back into an index into the original
		// s; the compression map is keyed by suffixes of s.
		compBegin int
		// compOff counts the bytes removed from bs by decoding escapes, i.e.
		// the difference between indexes into bs and indexes into s.
		compOff int
		// bs is a mutable copy of s, made on the first escape so that escape
		// sequences can be decoded in place. Only bs[:ls] is meaningful.
		bs []byte
		// wasDot reports whether the previous (unescaped) byte was a '.'.
		wasDot bool
	)
loop:
	for i := 0; i < ls; i++ {
		var c byte
		if bs == nil {
			c = s[i]
		} else {
			c = bs[i]
		}

		switch c {
		case '\\':
			if off+1 > len(msg) {
				return len(msg), ErrBuf
			}

			if bs == nil {
				bs = []byte(s)
			}

			// check for \DDD
			if isDDD(bs[i+1:]) {
				// Replace the 4-byte \DDD with the single decoded byte and
				// shift the rest of the name left by 3.
				bs[i] = dddToByte(bs[i+1:])
				copy(bs[i+1:ls-3], bs[i+4:])
				ls -= 3
				compOff += 3
			} else {
				// Drop the backslash; the escaped byte (possibly a '.') now
				// sits at bs[i] and is skipped by the loop's i++, so it is
				// treated as an ordinary label byte.
				copy(bs[i:ls-1], bs[i+1:])
				ls--
				compOff++
			}

			wasDot = false
		case '.':
			if i == 0 && len(s) > 1 {
				// leading dots are not legal except for the root zone
				return len(msg), ErrRdata
			}

			if wasDot {
				// two dots back to back is not legal
				return len(msg), ErrRdata
			}
			wasDot = true

			labelLen := i - begin
			if labelLen >= 1<<6 { // top two bits of length must be clear
				return len(msg), ErrRdata
			}

			// off can already (we're in a loop) be bigger than len(msg)
			// this happens when a name isn't fully qualified
			// NOTE(review): non-FQDN names are rejected above, so the second
			// line of this comment looks stale; the check still guards the
			// writes below, and (since labelLen >= 1 for non-root labels)
			// also leaves room for the 2-byte pointer written after the loop.
			if off+1+labelLen > len(msg) {
				return len(msg), ErrBuf
			}

			// Don't try to compress '.'
			// We should only compress when compress is true, but we should also still pick
			// up names that can be used for *future* compression(s).
			if compression.valid() && !isRootLabel(s, bs, begin, ls) {
				if p, ok := compression.find(s[compBegin:]); ok {
					// The first hit is the longest matching dname
					// keep the pointer offset we get back and store
					// the offset of the current name, because that's
					// where we need to insert the pointer later

					// If compress is true, we're allowed to compress this dname
					if compress {
						pointer = p // Where to point to
						break loop
					}
				} else if off < maxCompressionOffset {
					// Only offsets smaller than maxCompressionOffset can be used.
					compression.insert(s[compBegin:], off)
				}
			}

			// The following is covered by the length check above.
			msg[off] = byte(labelLen)

			if bs == nil {
				copy(msg[off+1:], s[begin:i])
			} else {
				copy(msg[off+1:], bs[begin:i])
			}
			off += 1 + labelLen

			begin = i + 1
			compBegin = begin + compOff
		default:
			wasDot = false
		}
	}

	// Root label is special
	if isRootLabel(s, bs, 0, ls) {
		return off, nil
	}

	// If we did compression and we find something add the pointer here
	if pointer != -1 {
		// We have two bytes (14 bits) to put the pointer in. Offsets inserted
		// above are < maxCompressionOffset, so XOR with 0xC000 sets the two
		// pointer marker bits; offsets supplied by external callers via
		// PackDomainName are assumed to respect the same bound.
		binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
		return off + 2, nil
	}

	// Terminating zero-length label.
	// NOTE(review): when off == len(msg) the terminator is not written, yet
	// off+1 is still returned without an error; callers that size msg
	// exactly rely on having at least one spare byte.
	if off < len(msg) {
		msg[off] = 0
	}

	return off + 1, nil
}

// isRootLabel returns whether s or bs, from off to end, is the root
// label ".".
//
// If bs is nil, s will be checked, otherwise bs will be checked.
func isRootLabel(s string, bs []byte, off, end int) bool {
	if bs == nil {
		return s[off:end] == "."
	}

	return end-off == 1 && bs[off] == '.'
}

// Unpack a domain name.
// In addition to simple sequences of counted strings (labels),
// domain names are allowed to refer to strings elsewhere in the
// packet, to avoid repeating common suffixes when returning
// many entries in a single domain. The pointers are marked
// by a length byte with the top two bits set. Ignoring those
// two bits, that byte and the next give a 14 bit offset from msg[0]
// where we should pick up the trail.
// Note that if we jump elsewhere in the packet,
// we return off1 == the offset after the first pointer we found,
// which is where the next record will start.
// In theory, the pointers are only allowed to jump backward.
// We let them jump anywhere and give up after maxCompressionPointers jumps.
372 373 // UnpackDomainName unpacks a domain name into a string. It returns 374 // the name, the new offset into msg and any error that occurred. 375 // 376 // When an error is encountered, the unpacked name will be discarded 377 // and len(msg) will be returned as the offset. 378 func UnpackDomainName(msg []byte, off int) (string, int, error) { 379 s := make([]byte, 0, maxDomainNamePresentationLength) 380 off1 := 0 381 lenmsg := len(msg) 382 budget := maxDomainNameWireOctets 383 ptr := 0 // number of pointers followed 384 Loop: 385 for { 386 if off >= lenmsg { 387 return "", lenmsg, ErrBuf 388 } 389 c := int(msg[off]) 390 off++ 391 switch c & 0xC0 { 392 case 0x00: 393 if c == 0x00 { 394 // end of name 395 break Loop 396 } 397 // literal string 398 if off+c > lenmsg { 399 return "", lenmsg, ErrBuf 400 } 401 budget -= c + 1 // +1 for the label separator 402 if budget <= 0 { 403 return "", lenmsg, ErrLongDomain 404 } 405 for _, b := range msg[off : off+c] { 406 if isDomainNameLabelSpecial(b) { 407 s = append(s, '\\', b) 408 } else if b < ' ' || b > '~' { 409 s = append(s, escapeByte(b)...) 410 } else { 411 s = append(s, b) 412 } 413 } 414 s = append(s, '.') 415 off += c 416 case 0xC0: 417 // pointer to somewhere else in msg. 418 // remember location after first ptr, 419 // since that's how many bytes we consumed. 420 // also, don't follow too many pointers -- 421 // maybe there's a loop. 
422 if off >= lenmsg { 423 return "", lenmsg, ErrBuf 424 } 425 c1 := msg[off] 426 off++ 427 if ptr == 0 { 428 off1 = off 429 } 430 if ptr++; ptr > maxCompressionPointers { 431 return "", lenmsg, &Error{err: "too many compression pointers"} 432 } 433 // pointer should guarantee that it advances and points forwards at least 434 // but the condition on previous three lines guarantees that it's 435 // at least loop-free 436 off = (c^0xC0)<<8 | int(c1) 437 default: 438 // 0x80 and 0x40 are reserved 439 return "", lenmsg, ErrRdata 440 } 441 } 442 if ptr == 0 { 443 off1 = off 444 } 445 if len(s) == 0 { 446 return ".", off1, nil 447 } 448 return string(s), off1, nil 449 } 450 451 func packTxt(txt []string, msg []byte, offset int) (int, error) { 452 if len(txt) == 0 { 453 if offset >= len(msg) { 454 return offset, ErrBuf 455 } 456 msg[offset] = 0 457 return offset, nil 458 } 459 var err error 460 for _, s := range txt { 461 offset, err = packTxtString(s, msg, offset) 462 if err != nil { 463 return offset, err 464 } 465 } 466 return offset, nil 467 } 468 469 func packTxtString(s string, msg []byte, offset int) (int, error) { 470 lenByteOffset := offset 471 if offset >= len(msg) || len(s) > 256*4+1 /* If all \DDD */ { 472 return offset, ErrBuf 473 } 474 offset++ 475 for i := 0; i < len(s); i++ { 476 if len(msg) <= offset { 477 return offset, ErrBuf 478 } 479 if s[i] == '\\' { 480 i++ 481 if i == len(s) { 482 break 483 } 484 // check for \DDD 485 if isDDD(s[i:]) { 486 msg[offset] = dddToByte(s[i:]) 487 i += 2 488 } else { 489 msg[offset] = s[i] 490 } 491 } else { 492 msg[offset] = s[i] 493 } 494 offset++ 495 } 496 l := offset - lenByteOffset - 1 497 if l > 255 { 498 return offset, &Error{err: "string exceeded 255 bytes in txt"} 499 } 500 msg[lenByteOffset] = byte(l) 501 return offset, nil 502 } 503 504 func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) { 505 if offset >= len(msg) || len(s) > len(tmp) { 506 return offset, ErrBuf 507 } 508 bs := 
tmp[:len(s)] 509 copy(bs, s) 510 for i := 0; i < len(bs); i++ { 511 if len(msg) <= offset { 512 return offset, ErrBuf 513 } 514 if bs[i] == '\\' { 515 i++ 516 if i == len(bs) { 517 break 518 } 519 // check for \DDD 520 if isDDD(bs[i:]) { 521 msg[offset] = dddToByte(bs[i:]) 522 i += 2 523 } else { 524 msg[offset] = bs[i] 525 } 526 } else { 527 msg[offset] = bs[i] 528 } 529 offset++ 530 } 531 return offset, nil 532 } 533 534 func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) { 535 off = off0 536 var s string 537 for off < len(msg) && err == nil { 538 s, off, err = unpackString(msg, off) 539 if err == nil { 540 ss = append(ss, s) 541 } 542 } 543 return 544 } 545 546 // Helpers for dealing with escaped bytes 547 func isDigit(b byte) bool { return b >= '0' && b <= '9' } 548 549 func isDDD[T ~[]byte | ~string](s T) bool { 550 return len(s) >= 3 && isDigit(s[0]) && isDigit(s[1]) && isDigit(s[2]) 551 } 552 553 func dddToByte[T ~[]byte | ~string](s T) byte { 554 _ = s[2] // bounds check hint to compiler; see golang.org/issue/14808 555 return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) 556 } 557 558 // Helper function for packing and unpacking 559 func intToBytes(i *big.Int, length int) []byte { 560 buf := i.Bytes() 561 if len(buf) < length { 562 b := make([]byte, length) 563 copy(b[length-len(buf):], buf) 564 return b 565 } 566 return buf 567 } 568 569 // PackRR packs a resource record rr into msg[off:]. 570 // See PackDomainName for documentation about the compression. 571 func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { 572 headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress) 573 if err == nil { 574 // packRR no longer sets the Rdlength field on the rr, but 575 // callers might be expecting it so we set it here. 
576 rr.Header().Rdlength = uint16(off1 - headerEnd) 577 } 578 return off1, err 579 } 580 581 func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) { 582 if rr == nil { 583 return len(msg), len(msg), &Error{err: "nil rr"} 584 } 585 586 headerEnd, err = rr.Header().packHeader(msg, off, compression, compress) 587 if err != nil { 588 return headerEnd, len(msg), err 589 } 590 591 off1, err = rr.pack(msg, headerEnd, compression, compress) 592 if err != nil { 593 return headerEnd, len(msg), err 594 } 595 596 rdlength := off1 - headerEnd 597 if int(uint16(rdlength)) != rdlength { // overflow 598 return headerEnd, len(msg), ErrRdata 599 } 600 601 // The RDLENGTH field is the last field in the header and we set it here. 602 binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength)) 603 return headerEnd, off1, nil 604 } 605 606 // UnpackRR unpacks msg[off:] into an RR. 607 func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) { 608 h, off, msg, err := unpackHeader(msg, off) 609 if err != nil { 610 return nil, len(msg), err 611 } 612 613 return UnpackRRWithHeader(h, msg, off) 614 } 615 616 // UnpackRRWithHeader unpacks the record type specific payload given an existing 617 // RR_Header. 
618 func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) { 619 if newFn, ok := TypeToRR[h.Rrtype]; ok { 620 rr = newFn() 621 *rr.Header() = h 622 } else { 623 rr = &RFC3597{Hdr: h} 624 } 625 626 if off < 0 || off > len(msg) { 627 return &h, off, &Error{err: "bad off"} 628 } 629 630 end := off + int(h.Rdlength) 631 if end < off || end > len(msg) { 632 return &h, end, &Error{err: "bad rdlength"} 633 } 634 635 if noRdata(h) { 636 return rr, off, nil 637 } 638 639 off, err = rr.unpack(msg, off) 640 if err != nil { 641 return nil, end, err 642 } 643 if off != end { 644 return &h, end, &Error{err: "bad rdlength"} 645 } 646 647 return rr, off, nil 648 } 649 650 // unpackRRslice unpacks msg[off:] into an []RR. 651 // If we cannot unpack the whole array, then it will return nil 652 func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) { 653 var r RR 654 // Don't pre-allocate, l may be under attacker control 655 var dst []RR 656 for i := 0; i < l; i++ { 657 off1 := off 658 r, off, err = UnpackRR(msg, off) 659 if err != nil { 660 off = len(msg) 661 break 662 } 663 // If offset does not increase anymore, l is a lie 664 if off1 == off { 665 break 666 } 667 dst = append(dst, r) 668 } 669 if err != nil && off == len(msg) { 670 dst = nil 671 } 672 return dst, off, err 673 } 674 675 // Convert a MsgHdr to a string, with dig-like headers: 676 // 677 // ;; opcode: QUERY, status: NOERROR, id: 48404 678 // 679 // ;; flags: qr aa rd ra; 680 func (h *MsgHdr) String() string { 681 if h == nil { 682 return "<nil> MsgHdr" 683 } 684 685 s := ";; opcode: " + OpcodeToString[h.Opcode] 686 s += ", status: " + RcodeToString[h.Rcode] 687 s += ", id: " + strconv.Itoa(int(h.Id)) + "\n" 688 689 s += ";; flags:" 690 if h.Response { 691 s += " qr" 692 } 693 if h.Authoritative { 694 s += " aa" 695 } 696 if h.Truncated { 697 s += " tc" 698 } 699 if h.RecursionDesired { 700 s += " rd" 701 } 702 if h.RecursionAvailable { 703 s += " ra" 704 } 705 
if h.Zero { // Hmm 706 s += " z" 707 } 708 if h.AuthenticatedData { 709 s += " ad" 710 } 711 if h.CheckingDisabled { 712 s += " cd" 713 } 714 715 s += ";" 716 return s 717 } 718 719 // Pack packs a Msg: it is converted to to wire format. 720 // If the dns.Compress is true the message will be in compressed wire format. 721 func (dns *Msg) Pack() (msg []byte, err error) { 722 return dns.PackBuffer(nil) 723 } 724 725 // PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated. 726 func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { 727 // If this message can't be compressed, avoid filling the 728 // compression map and creating garbage. 729 if dns.Compress && dns.isCompressible() { 730 compression := make(map[string]uint16) // Compression pointer mappings. 731 return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true) 732 } 733 734 return dns.packBufferWithCompressionMap(buf, compressionMap{}, false) 735 } 736 737 // packBufferWithCompressionMap packs a Msg, using the given buffer buf. 738 func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) { 739 if dns.Rcode < 0 || dns.Rcode > 0xFFF { 740 return nil, ErrRcode 741 } 742 743 // Set extended rcode unconditionally if we have an opt, this will allow 744 // resetting the extended rcode bits if they need to. 745 if opt := dns.IsEdns0(); opt != nil { 746 opt.SetExtendedRcode(uint16(dns.Rcode)) 747 } else if dns.Rcode > 0xF { 748 // If Rcode is an extended one and opt is nil, error out. 749 return nil, ErrExtendedRcode 750 } 751 752 // Convert convenient Msg into wire-like Header. 
753 var dh Header 754 dh.Id = dns.Id 755 dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF) 756 if dns.Response { 757 dh.Bits |= _QR 758 } 759 if dns.Authoritative { 760 dh.Bits |= _AA 761 } 762 if dns.Truncated { 763 dh.Bits |= _TC 764 } 765 if dns.RecursionDesired { 766 dh.Bits |= _RD 767 } 768 if dns.RecursionAvailable { 769 dh.Bits |= _RA 770 } 771 if dns.Zero { 772 dh.Bits |= _Z 773 } 774 if dns.AuthenticatedData { 775 dh.Bits |= _AD 776 } 777 if dns.CheckingDisabled { 778 dh.Bits |= _CD 779 } 780 781 dh.Qdcount = uint16(len(dns.Question)) 782 dh.Ancount = uint16(len(dns.Answer)) 783 dh.Nscount = uint16(len(dns.Ns)) 784 dh.Arcount = uint16(len(dns.Extra)) 785 786 // We need the uncompressed length here, because we first pack it and then compress it. 787 msg = buf 788 uncompressedLen := msgLenWithCompressionMap(dns, nil) 789 if packLen := uncompressedLen + 1; len(msg) < packLen { 790 msg = make([]byte, packLen) 791 } 792 793 // Pack it in: header and then the pieces. 794 off := 0 795 off, err = dh.pack(msg, off, compression, compress) 796 if err != nil { 797 return nil, err 798 } 799 for _, r := range dns.Question { 800 off, err = r.pack(msg, off, compression, compress) 801 if err != nil { 802 return nil, err 803 } 804 } 805 for _, r := range dns.Answer { 806 _, off, err = packRR(r, msg, off, compression, compress) 807 if err != nil { 808 return nil, err 809 } 810 } 811 for _, r := range dns.Ns { 812 _, off, err = packRR(r, msg, off, compression, compress) 813 if err != nil { 814 return nil, err 815 } 816 } 817 for _, r := range dns.Extra { 818 _, off, err = packRR(r, msg, off, compression, compress) 819 if err != nil { 820 return nil, err 821 } 822 } 823 return msg[:off], nil 824 } 825 826 func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) { 827 // If we are at the end of the message we should return *just* the 828 // header. This can still be useful to the caller. 9.9.9.9 sends these 829 // when responding with REFUSED for instance. 
830 if off == len(msg) { 831 // reset sections before returning 832 dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil 833 return nil 834 } 835 836 // Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are 837 // attacker controlled. This means we can't use them to pre-allocate 838 // slices. 839 dns.Question = nil 840 for i := 0; i < int(dh.Qdcount); i++ { 841 off1 := off 842 var q Question 843 q, off, err = unpackQuestion(msg, off) 844 if err != nil { 845 return err 846 } 847 if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie! 848 dh.Qdcount = uint16(i) 849 break 850 } 851 dns.Question = append(dns.Question, q) 852 } 853 854 dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off) 855 // The header counts might have been wrong so we need to update it 856 dh.Ancount = uint16(len(dns.Answer)) 857 if err == nil { 858 dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off) 859 } 860 // The header counts might have been wrong so we need to update it 861 dh.Nscount = uint16(len(dns.Ns)) 862 if err == nil { 863 dns.Extra, _, err = unpackRRslice(int(dh.Arcount), msg, off) 864 } 865 // The header counts might have been wrong so we need to update it 866 dh.Arcount = uint16(len(dns.Extra)) 867 868 // Set extended Rcode 869 if opt := dns.IsEdns0(); opt != nil { 870 dns.Rcode |= opt.ExtendedRcode() 871 } 872 873 // TODO(miek) make this an error? 874 // use PackOpt to let people tell how detailed the error reporting should be? 875 // if off != len(msg) { 876 // // println("dns: extra bytes in dns packet", off, "<", len(msg)) 877 // } 878 return err 879 880 } 881 882 // Unpack unpacks a binary message to a Msg structure. 883 func (dns *Msg) Unpack(msg []byte) (err error) { 884 dh, off, err := unpackMsgHdr(msg, 0) 885 if err != nil { 886 return err 887 } 888 889 dns.setHdr(dh) 890 return dns.unpack(dh, msg, off) 891 } 892 893 // Convert a complete message to a string with dig-like output. 
894 func (dns *Msg) String() string { 895 if dns == nil { 896 return "<nil> MsgHdr" 897 } 898 s := dns.MsgHdr.String() + " " 899 s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", " 900 s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", " 901 s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", " 902 s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n" 903 opt := dns.IsEdns0() 904 if opt != nil { 905 // OPT PSEUDOSECTION 906 s += opt.String() + "\n" 907 } 908 if len(dns.Question) > 0 { 909 s += "\n;; QUESTION SECTION:\n" 910 for _, r := range dns.Question { 911 s += r.String() + "\n" 912 } 913 } 914 if len(dns.Answer) > 0 { 915 s += "\n;; ANSWER SECTION:\n" 916 for _, r := range dns.Answer { 917 if r != nil { 918 s += r.String() + "\n" 919 } 920 } 921 } 922 if len(dns.Ns) > 0 { 923 s += "\n;; AUTHORITY SECTION:\n" 924 for _, r := range dns.Ns { 925 if r != nil { 926 s += r.String() + "\n" 927 } 928 } 929 } 930 if len(dns.Extra) > 0 && (opt == nil || len(dns.Extra) > 1) { 931 s += "\n;; ADDITIONAL SECTION:\n" 932 for _, r := range dns.Extra { 933 if r != nil && r.Header().Rrtype != TypeOPT { 934 s += r.String() + "\n" 935 } 936 } 937 } 938 return s 939 } 940 941 // isCompressible returns whether the msg may be compressible. 942 func (dns *Msg) isCompressible() bool { 943 // If we only have one question, there is nothing we can ever compress. 944 return len(dns.Question) > 1 || len(dns.Answer) > 0 || 945 len(dns.Ns) > 0 || len(dns.Extra) > 0 946 } 947 948 // Len returns the message length when in (un)compressed wire format. 949 // If dns.Compress is true compression it is taken into account. Len() 950 // is provided to be a faster way to get the size of the resulting packet, 951 // than packing it, measuring the size and discarding the buffer. 952 func (dns *Msg) Len() int { 953 // If this message can't be compressed, avoid filling the 954 // compression map and creating garbage. 
955 if dns.Compress && dns.isCompressible() { 956 compression := make(map[string]struct{}) 957 return msgLenWithCompressionMap(dns, compression) 958 } 959 960 return msgLenWithCompressionMap(dns, nil) 961 } 962 963 func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int { 964 l := headerSize 965 966 for _, r := range dns.Question { 967 l += r.len(l, compression) 968 } 969 for _, r := range dns.Answer { 970 if r != nil { 971 l += r.len(l, compression) 972 } 973 } 974 for _, r := range dns.Ns { 975 if r != nil { 976 l += r.len(l, compression) 977 } 978 } 979 for _, r := range dns.Extra { 980 if r != nil { 981 l += r.len(l, compression) 982 } 983 } 984 985 return l 986 } 987 988 func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int { 989 if s == "" || s == "." { 990 return 1 991 } 992 993 escaped := strings.Contains(s, "\\") 994 995 if compression != nil && (compress || off < maxCompressionOffset) { 996 // compressionLenSearch will insert the entry into the compression 997 // map if it doesn't contain it. 
998 if l, ok := compressionLenSearch(compression, s, off); ok && compress { 999 if escaped { 1000 return escapedNameLen(s[:l]) + 2 1001 } 1002 1003 return l + 2 1004 } 1005 } 1006 1007 if escaped { 1008 return escapedNameLen(s) + 1 1009 } 1010 1011 return len(s) + 1 1012 } 1013 1014 func escapedNameLen(s string) int { 1015 nameLen := len(s) 1016 for i := 0; i < len(s); i++ { 1017 if s[i] != '\\' { 1018 continue 1019 } 1020 1021 if isDDD(s[i+1:]) { 1022 nameLen -= 3 1023 i += 3 1024 } else { 1025 nameLen-- 1026 i++ 1027 } 1028 } 1029 1030 return nameLen 1031 } 1032 1033 func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) { 1034 for off, end := 0, false; !end; off, end = NextLabel(s, off) { 1035 if _, ok := c[s[off:]]; ok { 1036 return off, true 1037 } 1038 1039 if msgOff+off < maxCompressionOffset { 1040 c[s[off:]] = struct{}{} 1041 } 1042 } 1043 1044 return 0, false 1045 } 1046 1047 // Copy returns a new RR which is a deep-copy of r. 1048 func Copy(r RR) RR { return r.copy() } 1049 1050 // Len returns the length (in octets) of the uncompressed RR in wire format. 1051 func Len(r RR) int { return r.len(0, nil) } 1052 1053 // Copy returns a new *Msg which is a deep-copy of dns. 1054 func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) } 1055 1056 // CopyTo copies the contents to the provided message using a deep-copy and returns the copy. 
1057 func (dns *Msg) CopyTo(r1 *Msg) *Msg { 1058 r1.MsgHdr = dns.MsgHdr 1059 r1.Compress = dns.Compress 1060 1061 if len(dns.Question) > 0 { 1062 // TODO(miek): Question is an immutable value, ok to do a shallow-copy 1063 r1.Question = cloneSlice(dns.Question) 1064 } 1065 1066 rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra)) 1067 r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):] 1068 r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):] 1069 r1.Extra = rrArr[:0:len(dns.Extra)] 1070 1071 for _, r := range dns.Answer { 1072 r1.Answer = append(r1.Answer, r.copy()) 1073 } 1074 1075 for _, r := range dns.Ns { 1076 r1.Ns = append(r1.Ns, r.copy()) 1077 } 1078 1079 for _, r := range dns.Extra { 1080 r1.Extra = append(r1.Extra, r.copy()) 1081 } 1082 1083 return r1 1084 } 1085 1086 func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) { 1087 off, err := packDomainName(q.Name, msg, off, compression, compress) 1088 if err != nil { 1089 return off, err 1090 } 1091 off, err = packUint16(q.Qtype, msg, off) 1092 if err != nil { 1093 return off, err 1094 } 1095 off, err = packUint16(q.Qclass, msg, off) 1096 if err != nil { 1097 return off, err 1098 } 1099 return off, nil 1100 } 1101 1102 func unpackQuestion(msg []byte, off int) (Question, int, error) { 1103 var ( 1104 q Question 1105 err error 1106 ) 1107 q.Name, off, err = UnpackDomainName(msg, off) 1108 if err != nil { 1109 return q, off, err 1110 } 1111 if off == len(msg) { 1112 return q, off, nil 1113 } 1114 q.Qtype, off, err = unpackUint16(msg, off) 1115 if err != nil { 1116 return q, off, err 1117 } 1118 if off == len(msg) { 1119 return q, off, nil 1120 } 1121 q.Qclass, off, err = unpackUint16(msg, off) 1122 if off == len(msg) { 1123 return q, off, nil 1124 } 1125 return q, off, err 1126 } 1127 1128 func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) { 1129 off, err := packUint16(dh.Id, msg, 
off) 1130 if err != nil { 1131 return off, err 1132 } 1133 off, err = packUint16(dh.Bits, msg, off) 1134 if err != nil { 1135 return off, err 1136 } 1137 off, err = packUint16(dh.Qdcount, msg, off) 1138 if err != nil { 1139 return off, err 1140 } 1141 off, err = packUint16(dh.Ancount, msg, off) 1142 if err != nil { 1143 return off, err 1144 } 1145 off, err = packUint16(dh.Nscount, msg, off) 1146 if err != nil { 1147 return off, err 1148 } 1149 off, err = packUint16(dh.Arcount, msg, off) 1150 if err != nil { 1151 return off, err 1152 } 1153 return off, nil 1154 } 1155 1156 func unpackMsgHdr(msg []byte, off int) (Header, int, error) { 1157 var ( 1158 dh Header 1159 err error 1160 ) 1161 dh.Id, off, err = unpackUint16(msg, off) 1162 if err != nil { 1163 return dh, off, err 1164 } 1165 dh.Bits, off, err = unpackUint16(msg, off) 1166 if err != nil { 1167 return dh, off, err 1168 } 1169 dh.Qdcount, off, err = unpackUint16(msg, off) 1170 if err != nil { 1171 return dh, off, err 1172 } 1173 dh.Ancount, off, err = unpackUint16(msg, off) 1174 if err != nil { 1175 return dh, off, err 1176 } 1177 dh.Nscount, off, err = unpackUint16(msg, off) 1178 if err != nil { 1179 return dh, off, err 1180 } 1181 dh.Arcount, off, err = unpackUint16(msg, off) 1182 if err != nil { 1183 return dh, off, err 1184 } 1185 return dh, off, nil 1186 } 1187 1188 // setHdr set the header in the dns using the binary data in dh. 1189 func (dns *Msg) setHdr(dh Header) { 1190 dns.Id = dh.Id 1191 dns.Response = dh.Bits&_QR != 0 1192 dns.Opcode = int(dh.Bits>>11) & 0xF 1193 dns.Authoritative = dh.Bits&_AA != 0 1194 dns.Truncated = dh.Bits&_TC != 0 1195 dns.RecursionDesired = dh.Bits&_RD != 0 1196 dns.RecursionAvailable = dh.Bits&_RA != 0 1197 dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite. 1198 dns.AuthenticatedData = dh.Bits&_AD != 0 1199 dns.CheckingDisabled = dh.Bits&_CD != 0 1200 dns.Rcode = int(dh.Bits & 0xF) 1201 }