compression.go (3192B)
1 // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "compress/flate" 9 "errors" 10 "io" 11 "strings" 12 "sync" 13 ) 14 15 const ( 16 minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 17 maxCompressionLevel = flate.BestCompression 18 defaultCompressionLevel = 1 19 ) 20 21 var ( 22 flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool 23 flateReaderPool = sync.Pool{New: func() interface{} { 24 return flate.NewReader(nil) 25 }} 26 ) 27 28 func decompressNoContextTakeover(r io.Reader) io.ReadCloser { 29 const tail = 30 // Add four bytes as specified in RFC 31 "\x00\x00\xff\xff" + 32 // Add final block to squelch unexpected EOF error from flate reader. 33 "\x01\x00\x00\xff\xff" 34 35 fr, _ := flateReaderPool.Get().(io.ReadCloser) 36 fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) 37 return &flateReadWrapper{fr} 38 } 39 40 func isValidCompressionLevel(level int) bool { 41 return minCompressionLevel <= level && level <= maxCompressionLevel 42 } 43 44 func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { 45 p := &flateWriterPools[level-minCompressionLevel] 46 tw := &truncWriter{w: w} 47 fw, _ := p.Get().(*flate.Writer) 48 if fw == nil { 49 fw, _ = flate.NewWriter(tw, level) 50 } else { 51 fw.Reset(tw) 52 } 53 return &flateWriteWrapper{fw: fw, tw: tw, p: p} 54 } 55 56 // truncWriter is an io.Writer that writes all but the last four bytes of the 57 // stream to another io.Writer. 58 type truncWriter struct { 59 w io.WriteCloser 60 n int 61 p [4]byte 62 } 63 64 func (w *truncWriter) Write(p []byte) (int, error) { 65 n := 0 66 67 // fill buffer first for simplicity. 68 if w.n < len(w.p) { 69 n = copy(w.p[w.n:], p) 70 p = p[n:] 71 w.n += n 72 if len(p) == 0 { 73 return n, nil 74 } 75 } 76 77 m := len(p) 78 if m > len(w.p) { 79 m = len(w.p) 80 } 81 82 if nn, err := w.w.Write(w.p[:m]); err != nil { 83 return n + nn, err 84 } 85 86 copy(w.p[:], w.p[m:]) 87 copy(w.p[len(w.p)-m:], p[len(p)-m:]) 88 nn, err := w.w.Write(p[:len(p)-m]) 89 return n + nn, err 90 } 91 92 type flateWriteWrapper struct { 93 fw *flate.Writer 94 tw *truncWriter 95 p *sync.Pool 96 } 97 98 func (w *flateWriteWrapper) Write(p []byte) (int, error) { 99 if w.fw == nil { 100 return 0, errWriteClosed 101 } 102 return w.fw.Write(p) 103 } 104 105 func (w *flateWriteWrapper) Close() error { 106 if w.fw == nil { 107 return errWriteClosed 108 } 109 err1 := w.fw.Flush() 110 w.p.Put(w.fw) 111 w.fw = nil 112 if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { 113 return errors.New("websocket: internal error, unexpected bytes at end of flate stream") 114 } 115 err2 := w.tw.w.Close() 116 if err1 != nil { 117 return err1 118 } 119 return err2 120 } 121 122 type flateReadWrapper struct { 123 fr io.ReadCloser 124 } 125 126 func (r *flateReadWrapper) Read(p []byte) (int, error) { 127 if r.fr == nil { 128 return 0, io.ErrClosedPipe 129 } 130 n, err := r.fr.Read(p) 131 if err == io.EOF { 132 // Preemptively place the reader back in the pool. This helps with 133 // scenarios where the application does not call NextReader() soon after 134 // this final read. 135 r.Close() 136 } 137 return n, err 138 } 139 140 func (r *flateReadWrapper) Close() error { 141 if r.fr == nil { 142 return io.ErrClosedPipe 143 } 144 err := r.fr.Close() 145 flateReaderPool.Put(r.fr) 146 r.fr = nil 147 return err 148 }