copy.go (2825B)
1 package fastcopy 2 3 import ( 4 "io" 5 "sync" 6 _ "unsafe" // link to io.errInvalidWrite. 7 ) 8 9 var ( 10 // global pool instance. 11 pool = CopyPool{size: 4096} 12 13 //go:linkname errInvalidWrite io.errInvalidWrite 14 errInvalidWrite error 15 ) 16 17 // CopyPool provides a memory pool of byte 18 // buffers for io copies from readers to writers. 19 type CopyPool struct { 20 size int 21 pool sync.Pool 22 } 23 24 // See CopyPool.Buffer(). 25 func Buffer(sz int) int { 26 return pool.Buffer(sz) 27 } 28 29 // See CopyPool.CopyN(). 30 func CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) { 31 return pool.CopyN(dst, src, n) 32 } 33 34 // See CopyPool.Copy(). 35 func Copy(dst io.Writer, src io.Reader) (int64, error) { 36 return pool.Copy(dst, src) 37 } 38 39 // Buffer sets the pool buffer size to allocate. Returns current size. 40 // Note this is NOT atomically safe, please call BEFORE other calls to CopyPool. 41 func (cp *CopyPool) Buffer(sz int) int { 42 if sz > 0 { 43 // update size 44 cp.size = sz 45 } else if cp.size < 1 { 46 // default size 47 return 4096 48 } 49 return cp.size 50 } 51 52 // CopyN performs the same logic as io.CopyN(), with the difference 53 // being that the byte buffer is acquired from a memory pool. 54 func (cp *CopyPool) CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) { 55 written, err := cp.Copy(dst, io.LimitReader(src, n)) 56 if written == n { 57 return n, nil 58 } 59 if written < n && err == nil { 60 // src stopped early; must have been EOF. 61 err = io.EOF 62 } 63 return written, err 64 } 65 66 // Copy performs the same logic as io.Copy(), with the difference 67 // being that the byte buffer is acquired from a memory pool. 68 func (cp *CopyPool) Copy(dst io.Writer, src io.Reader) (int64, error) { 69 // Prefer using io.WriterTo to do the copy (avoids alloc + copy) 70 if wt, ok := src.(io.WriterTo); ok { 71 return wt.WriteTo(dst) 72 } 73 74 // Prefer using io.ReaderFrom to do the copy. 75 if rt, ok := dst.(io.ReaderFrom); ok { 76 return rt.ReadFrom(src) 77 } 78 79 var buf []byte 80 81 if b, ok := cp.pool.Get().(*[]byte); ok { 82 // Acquired buf from pool 83 buf = *b 84 } else { 85 // Allocate new buffer of size 86 buf = make([]byte, cp.Buffer(0)) 87 } 88 89 // Defer release to pool 90 defer cp.pool.Put(&buf) 91 92 var n int64 93 for { 94 // Perform next read into buf 95 nr, err := src.Read(buf) 96 if nr > 0 { 97 // We error check AFTER checking 98 // no. read bytes so incomplete 99 // read still gets written up to nr. 100 101 // Perform next write from buf 102 nw, ew := dst.Write(buf[0:nr]) 103 104 // Check for valid write 105 if nw < 0 || nr < nw { 106 if ew == nil { 107 ew = errInvalidWrite 108 } 109 return n, ew 110 } 111 112 // Incr total count 113 n += int64(nw) 114 115 // Check write error 116 if ew != nil { 117 return n, ew 118 } 119 120 // Check unequal read/writes 121 if nr != nw { 122 return n, io.ErrShortWrite 123 } 124 } 125 126 // Return on err 127 if err != nil { 128 if err == io.EOF { 129 err = nil // expected 130 } 131 return n, err 132 } 133 } 134 }