huffman.go (5962B)
1 // Copyright 2014 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package hpack 6 7 import ( 8 "bytes" 9 "errors" 10 "io" 11 "sync" 12 ) 13 14 var bufPool = sync.Pool{ 15 New: func() interface{} { return new(bytes.Buffer) }, 16 } 17 18 // HuffmanDecode decodes the string in v and writes the expanded 19 // result to w, returning the number of bytes written to w and the 20 // Write call's return value. At most one Write call is made. 21 func HuffmanDecode(w io.Writer, v []byte) (int, error) { 22 buf := bufPool.Get().(*bytes.Buffer) 23 buf.Reset() 24 defer bufPool.Put(buf) 25 if err := huffmanDecode(buf, 0, v); err != nil { 26 return 0, err 27 } 28 return w.Write(buf.Bytes()) 29 } 30 31 // HuffmanDecodeToString decodes the string in v. 32 func HuffmanDecodeToString(v []byte) (string, error) { 33 buf := bufPool.Get().(*bytes.Buffer) 34 buf.Reset() 35 defer bufPool.Put(buf) 36 if err := huffmanDecode(buf, 0, v); err != nil { 37 return "", err 38 } 39 return buf.String(), nil 40 } 41 42 // ErrInvalidHuffman is returned for errors found decoding 43 // Huffman-encoded strings. 44 var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data") 45 46 // huffmanDecode decodes v to buf. 47 // If maxLen is greater than 0, attempts to write more to buf than 48 // maxLen bytes will return ErrStringLength. 49 func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error { 50 rootHuffmanNode := getRootHuffmanNode() 51 n := rootHuffmanNode 52 // cur is the bit buffer that has not been fed into n. 53 // cbits is the number of low order bits in cur that are valid. 54 // sbits is the number of bits of the symbol prefix being decoded. 55 cur, cbits, sbits := uint(0), uint8(0), uint8(0) 56 for _, b := range v { 57 cur = cur<<8 | uint(b) 58 cbits += 8 59 sbits += 8 60 for cbits >= 8 { 61 idx := byte(cur >> (cbits - 8)) 62 n = n.children[idx] 63 if n == nil { 64 return ErrInvalidHuffman 65 } 66 if n.children == nil { 67 if maxLen != 0 && buf.Len() == maxLen { 68 return ErrStringLength 69 } 70 buf.WriteByte(n.sym) 71 cbits -= n.codeLen 72 n = rootHuffmanNode 73 sbits = cbits 74 } else { 75 cbits -= 8 76 } 77 } 78 } 79 for cbits > 0 { 80 n = n.children[byte(cur<<(8-cbits))] 81 if n == nil { 82 return ErrInvalidHuffman 83 } 84 if n.children != nil || n.codeLen > cbits { 85 break 86 } 87 if maxLen != 0 && buf.Len() == maxLen { 88 return ErrStringLength 89 } 90 buf.WriteByte(n.sym) 91 cbits -= n.codeLen 92 n = rootHuffmanNode 93 sbits = cbits 94 } 95 if sbits > 7 { 96 // Either there was an incomplete symbol, or overlong padding. 97 // Both are decoding errors per RFC 7541 section 5.2. 98 return ErrInvalidHuffman 99 } 100 if mask := uint(1<<cbits - 1); cur&mask != mask { 101 // Trailing bits must be a prefix of EOS per RFC 7541 section 5.2. 102 return ErrInvalidHuffman 103 } 104 105 return nil 106 } 107 108 // incomparable is a zero-width, non-comparable type. Adding it to a struct 109 // makes that struct also non-comparable, and generally doesn't add 110 // any size (as long as it's first). 111 type incomparable [0]func() 112 113 type node struct { 114 _ incomparable 115 116 // children is non-nil for internal nodes 117 children *[256]*node 118 119 // The following are only valid if children is nil: 120 codeLen uint8 // number of bits that led to the output of sym 121 sym byte // output symbol 122 } 123 124 func newInternalNode() *node { 125 return &node{children: new([256]*node)} 126 } 127 128 var ( 129 buildRootOnce sync.Once 130 lazyRootHuffmanNode *node 131 ) 132 133 func getRootHuffmanNode() *node { 134 buildRootOnce.Do(buildRootHuffmanNode) 135 return lazyRootHuffmanNode 136 } 137 138 func buildRootHuffmanNode() { 139 if len(huffmanCodes) != 256 { 140 panic("unexpected size") 141 } 142 lazyRootHuffmanNode = newInternalNode() 143 // allocate a leaf node for each of the 256 symbols 144 leaves := new([256]node) 145 146 for sym, code := range huffmanCodes { 147 codeLen := huffmanCodeLen[sym] 148 149 cur := lazyRootHuffmanNode 150 for codeLen > 8 { 151 codeLen -= 8 152 i := uint8(code >> codeLen) 153 if cur.children[i] == nil { 154 cur.children[i] = newInternalNode() 155 } 156 cur = cur.children[i] 157 } 158 shift := 8 - codeLen 159 start, end := int(uint8(code<<shift)), int(1<<shift) 160 161 leaves[sym].sym = byte(sym) 162 leaves[sym].codeLen = codeLen 163 for i := start; i < start+end; i++ { 164 cur.children[i] = &leaves[sym] 165 } 166 } 167 } 168 169 // AppendHuffmanString appends s, as encoded in Huffman codes, to dst 170 // and returns the extended buffer. 171 func AppendHuffmanString(dst []byte, s string) []byte { 172 // This relies on the maximum huffman code length being 30 (See tables.go huffmanCodeLen array) 173 // So if a uint64 buffer has less than 32 valid bits can always accommodate another huffmanCode. 174 var ( 175 x uint64 // buffer 176 n uint // number valid of bits present in x 177 ) 178 for i := 0; i < len(s); i++ { 179 c := s[i] 180 n += uint(huffmanCodeLen[c]) 181 x <<= huffmanCodeLen[c] % 64 182 x |= uint64(huffmanCodes[c]) 183 if n >= 32 { 184 n %= 32 // Normally would be -= 32 but %= 32 informs compiler 0 <= n <= 31 for upcoming shift 185 y := uint32(x >> n) // Compiler doesn't combine memory writes if y isn't uint32 186 dst = append(dst, byte(y>>24), byte(y>>16), byte(y>>8), byte(y)) 187 } 188 } 189 // Add padding bits if necessary 190 if over := n % 8; over > 0 { 191 const ( 192 eosCode = 0x3fffffff 193 eosNBits = 30 194 eosPadByte = eosCode >> (eosNBits - 8) 195 ) 196 pad := 8 - over 197 x = (x << pad) | (eosPadByte >> over) 198 n += pad // 8 now divides into n exactly 199 } 200 // n in (0, 8, 16, 24, 32) 201 switch n / 8 { 202 case 0: 203 return dst 204 case 1: 205 return append(dst, byte(x)) 206 case 2: 207 y := uint16(x) 208 return append(dst, byte(y>>8), byte(y)) 209 case 3: 210 y := uint16(x >> 8) 211 return append(dst, byte(y>>8), byte(y), byte(x)) 212 } 213 // case 4: 214 y := uint32(x) 215 return append(dst, byte(y>>24), byte(y>>16), byte(y>>8), byte(y)) 216 } 217 218 // HuffmanEncodeLength returns the number of bytes required to encode 219 // s in Huffman codes. The result is round up to byte boundary. 220 func HuffmanEncodeLength(s string) uint64 { 221 n := uint64(0) 222 for i := 0; i < len(s); i++ { 223 n += uint64(huffmanCodeLen[s[i]]) 224 } 225 return (n + 7) / 8 226 }