decode.go (12001B)
1 // Copyright 2011 The Snappy-Go Authors. All rights reserved. 2 // Copyright (c) 2019 Klaus Post. All rights reserved. 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 package s2 7 8 import ( 9 "encoding/binary" 10 "errors" 11 "fmt" 12 "strconv" 13 ) 14 15 var ( 16 // ErrCorrupt reports that the input is invalid. 17 ErrCorrupt = errors.New("s2: corrupt input") 18 // ErrCRC reports that the input failed CRC validation (streams only) 19 ErrCRC = errors.New("s2: corrupt input, crc mismatch") 20 // ErrTooLarge reports that the uncompressed length is too large. 21 ErrTooLarge = errors.New("s2: decoded block is too large") 22 // ErrUnsupported reports that the input isn't supported. 23 ErrUnsupported = errors.New("s2: unsupported input") 24 ) 25 26 // DecodedLen returns the length of the decoded block. 27 func DecodedLen(src []byte) (int, error) { 28 v, _, err := decodedLen(src) 29 return v, err 30 } 31 32 // decodedLen returns the length of the decoded block and the number of bytes 33 // that the length header occupied. 34 func decodedLen(src []byte) (blockLen, headerLen int, err error) { 35 v, n := binary.Uvarint(src) 36 if n <= 0 || v > 0xffffffff { 37 return 0, 0, ErrCorrupt 38 } 39 40 const wordSize = 32 << (^uint(0) >> 32 & 1) 41 if wordSize == 32 && v > 0x7fffffff { 42 return 0, 0, ErrTooLarge 43 } 44 return int(v), n, nil 45 } 46 47 const ( 48 decodeErrCodeCorrupt = 1 49 ) 50 51 // Decode returns the decoded form of src. The returned slice may be a sub- 52 // slice of dst if dst was large enough to hold the entire decoded block. 53 // Otherwise, a newly allocated slice will be returned. 54 // 55 // The dst and src must not overlap. It is valid to pass a nil dst. 
func Decode(dst, src []byte) ([]byte, error) {
	// Parse the uvarint length header; s is the number of header bytes.
	dLen, s, err := decodedLen(src)
	if err != nil {
		return nil, err
	}
	// Reuse dst's backing array when it is big enough, otherwise allocate.
	if dLen <= cap(dst) {
		dst = dst[:dLen]
	} else {
		dst = make([]byte, dLen)
	}
	// Any non-zero status from the block decoder is reported as ErrCorrupt.
	if s2Decode(dst, src[s:]) != 0 {
		return nil, ErrCorrupt
	}
	return dst, nil
}

// s2DecodeDict writes the decoding of src to dst. It assumes that the varint-encoded
// length of the decompressed bytes has already been read, and that len(dst)
// equals that length.
//
// When dict is nil, decoding is delegated to s2Decode. Otherwise, copy
// elements whose offset reaches back past the start of dst (d < offset) are
// served from the tail of dict.dict. Such references are only accepted while
// the output position d is at most MaxDictSrcOffset, and they must lie
// entirely inside the dictionary. The initial repeat offset is
// len(dict.dict) - dict.repeat.
//
// It returns 0 on success or a decodeErrCodeXxx error code on failure.
func s2DecodeDict(dst, src []byte, dict *Dict) int {
	if dict == nil {
		return s2Decode(dst, src)
	}
	const debug = false
	const debugErrs = debug

	if debug {
		fmt.Println("Starting decode, dst len:", len(dst))
	}
	// d: write position in dst; s: read position in src; offset: most recent
	// copy offset, reused by repeat elements.
	var d, s, length int
	offset := len(dict.dict) - dict.repeat

	// Fast loop: runs while more than 5 bytes of src remain (s < len(src)-5).
	// Every element header read here (tag plus at most 4 extra bytes) fits in
	// those bytes, so header reads need no explicit length checks. Literal
	// payload lengths are still checked against len(src)-s.
	for s < len(src)-5 {
		// Removing bounds checks is SLOWER when doing
		// in := src[s:s+5]
		// (checked on Go 1.18).
		// The low two bits of the tag byte select the element type.
		switch src[s] & 0x03 {
		case tagLiteral:
			// x holds (literal length - 1). Values 60..63 mean the
			// length-1 is stored little-endian in the next 1..4 bytes.
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				x = uint32(src[s-1])
			case x == 61:
				in := src[s : s+3]
				x = uint32(in[1]) | uint32(in[2])<<8
				s += 3
			case x == 62:
				in := src[s : s+4]
				// Load as 32 bit and shift down.
				x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
				x >>= 8
				s += 4
			case x == 63:
				in := src[s : s+5]
				x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
				s += 5
			}
			length = int(x) + 1
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}
			// The literal must fit in both dst and the remaining src. When int
			// is 32 bits, int(x)+1 can wrap to a non-positive value.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			// Two-byte element: length code in tag bits 2-4, 11-bit offset
			// from tag bits 5-7 (as bits 8-10) plus the following byte.
			s += 2
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			length = int(src[s-2]) >> 2 & 0x7
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				// Offset 0 marks a repeat of the previous offset. Length
				// codes 5/6/7 read the length from 1/2/3 extra bytes.
				switch length {
				case 5:
					length = int(src[s]) + 4
					s += 1
				case 6:
					in := src[s : s+2]
					length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
					s += 2
				case 7:
					in := src[s : s+3]
					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
					s += 3
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			length += 4
		case tagCopy2:
			// Three-byte element: length-1 in tag bits 2-7, 16-bit
			// little-endian offset.
			in := src[s : s+3]
			offset = int(uint32(in[1]) | uint32(in[2])<<8)
			length = 1 + int(in[0])>>2
			s += 3

		case tagCopy4:
			// Five-byte element: length-1 in tag bits 2-7, 32-bit
			// little-endian offset.
			in := src[s : s+5]
			offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
			length = 1 + int(in[0])>>2
			s += 5
		}

		// A non-positive offset is never valid (for example, a tagCopy4
		// offset that wraps negative when int is 32 bits). The match must
		// also fit in the rest of dst.
		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// The match starts offset-d bytes before the end of the dictionary.
		// It must lie entirely within dict.dict; a match spanning from the
		// dictionary into dst is rejected.
		if d < offset {
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			startOff := len(dict.dict) - offset + d
			if startOff < 0 || startOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Printf("offset (%d) + length (%d) bigger than dict (%d)\n", offset, length, len(dict.dict))
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("dict copy, length:", length, "offset:", offset, "d-after:", d+length, "dict start offset:", startOff)
			}
			copy(dst[d:d+length], dict.dict[startOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// Remaining with extra checks...
	// Tail loop: decodes whatever the fast loop left (at most the last 5 bytes
	// of src, or all of src when it is shorter than 6 bytes). It uses the same
	// element layout, but each multi-byte header read is preceded by an
	// explicit bounds check.
	for s < len(src) {
		switch src[s] & 0x03 {
		case tagLiteral:
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-1])
			case x == 61:
				s += 3
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-2]) | uint32(src[s-1])<<8
			case x == 62:
				s += 4
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
			case x == 63:
				s += 5
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
			}
			length = int(x) + 1
			// Same literal validation as in the fast loop.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			s += 2
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = int(src[s-2]) >> 2 & 0x7
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				switch length {
				case 5:
					s += 1
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-1])) + 4
				case 6:
					s += 2
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
				case 7:
					s += 3
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			length += 4
		case tagCopy2:
			s += 3
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-3])>>2
			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)

		case tagCopy4:
			s += 5
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-5])>>2
			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
		}

		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// rOff is the dictionary index where the match starts. The match must
		// end within dict.dict and must not start before index 0.
		if d < offset {
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			rOff := len(dict.dict) - (offset - d)
			if debug {
				fmt.Println("starting dict entry from dict offset", len(dict.dict)-rOff)
			}
			if rOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Println("err: END offset", rOff+length, "bigger than dict", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			if rOff < 0 {
				if debugErrs {
					fmt.Println("err: START offset", rOff, "less than 0", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			copy(dst[d:d+length], dict.dict[rOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// The stream must produce exactly len(dst) bytes.
	if d != len(dst) {
		if debugErrs {
			fmt.Println("wanted length", len(dst), "got", d)
		}
		return decodeErrCodeCorrupt
	}
	return 0
}