gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

copy.go (2825B)


      1 package fastcopy
      2 
      3 import (
      4 	"io"
      5 	"sync"
      6 	_ "unsafe" // link to io.errInvalidWrite.
      7 )
      8 
      9 var (
     10 	// global pool instance.
     11 	pool = CopyPool{size: 4096}
     12 
     13 	//go:linkname errInvalidWrite io.errInvalidWrite
     14 	errInvalidWrite error
     15 )
     16 
     17 // CopyPool provides a memory pool of byte
     18 // buffers for io copies from readers to writers.
     19 type CopyPool struct {
     20 	size int
     21 	pool sync.Pool
     22 }
     23 
     24 // See CopyPool.Buffer().
     25 func Buffer(sz int) int {
     26 	return pool.Buffer(sz)
     27 }
     28 
     29 // See CopyPool.CopyN().
     30 func CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
     31 	return pool.CopyN(dst, src, n)
     32 }
     33 
     34 // See CopyPool.Copy().
     35 func Copy(dst io.Writer, src io.Reader) (int64, error) {
     36 	return pool.Copy(dst, src)
     37 }
     38 
     39 // Buffer sets the pool buffer size to allocate. Returns current size.
     40 // Note this is NOT atomically safe, please call BEFORE other calls to CopyPool.
     41 func (cp *CopyPool) Buffer(sz int) int {
     42 	if sz > 0 {
     43 		// update size
     44 		cp.size = sz
     45 	} else if cp.size < 1 {
     46 		// default size
     47 		return 4096
     48 	}
     49 	return cp.size
     50 }
     51 
     52 // CopyN performs the same logic as io.CopyN(), with the difference
     53 // being that the byte buffer is acquired from a memory pool.
     54 func (cp *CopyPool) CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
     55 	written, err := cp.Copy(dst, io.LimitReader(src, n))
     56 	if written == n {
     57 		return n, nil
     58 	}
     59 	if written < n && err == nil {
     60 		// src stopped early; must have been EOF.
     61 		err = io.EOF
     62 	}
     63 	return written, err
     64 }
     65 
     66 // Copy performs the same logic as io.Copy(), with the difference
     67 // being that the byte buffer is acquired from a memory pool.
     68 func (cp *CopyPool) Copy(dst io.Writer, src io.Reader) (int64, error) {
     69 	// Prefer using io.WriterTo to do the copy (avoids alloc + copy)
     70 	if wt, ok := src.(io.WriterTo); ok {
     71 		return wt.WriteTo(dst)
     72 	}
     73 
     74 	// Prefer using io.ReaderFrom to do the copy.
     75 	if rt, ok := dst.(io.ReaderFrom); ok {
     76 		return rt.ReadFrom(src)
     77 	}
     78 
     79 	var buf []byte
     80 
     81 	if b, ok := cp.pool.Get().(*[]byte); ok {
     82 		// Acquired buf from pool
     83 		buf = *b
     84 	} else {
     85 		// Allocate new buffer of size
     86 		buf = make([]byte, cp.Buffer(0))
     87 	}
     88 
     89 	// Defer release to pool
     90 	defer cp.pool.Put(&buf)
     91 
     92 	var n int64
     93 	for {
     94 		// Perform next read into buf
     95 		nr, err := src.Read(buf)
     96 		if nr > 0 {
     97 			// We error check AFTER checking
     98 			// no. read bytes so incomplete
     99 			// read still gets written up to nr.
    100 
    101 			// Perform next write from buf
    102 			nw, ew := dst.Write(buf[0:nr])
    103 
    104 			// Check for valid write
    105 			if nw < 0 || nr < nw {
    106 				if ew == nil {
    107 					ew = errInvalidWrite
    108 				}
    109 				return n, ew
    110 			}
    111 
    112 			// Incr total count
    113 			n += int64(nw)
    114 
    115 			// Check write error
    116 			if ew != nil {
    117 				return n, ew
    118 			}
    119 
    120 			// Check unequal read/writes
    121 			if nr != nw {
    122 				return n, io.ErrShortWrite
    123 			}
    124 		}
    125 
    126 		// Return on err
    127 		if err != nil {
    128 			if err == io.EOF {
    129 				err = nil // expected
    130 			}
    131 			return n, err
    132 		}
    133 	}
    134 }