gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

copy_from.go (6515B)


      1 package pgx
      2 
      3 import (
      4 	"bytes"
      5 	"context"
      6 	"fmt"
      7 	"io"
      8 
      9 	"github.com/jackc/pgx/v5/internal/pgio"
     10 	"github.com/jackc/pgx/v5/pgconn"
     11 )
     12 
     13 // CopyFromRows returns a CopyFromSource interface over the provided rows slice
     14 // making it usable by *Conn.CopyFrom.
     15 func CopyFromRows(rows [][]any) CopyFromSource {
     16 	return &copyFromRows{rows: rows, idx: -1}
     17 }
     18 
     19 type copyFromRows struct {
     20 	rows [][]any
     21 	idx  int
     22 }
     23 
     24 func (ctr *copyFromRows) Next() bool {
     25 	ctr.idx++
     26 	return ctr.idx < len(ctr.rows)
     27 }
     28 
     29 func (ctr *copyFromRows) Values() ([]any, error) {
     30 	return ctr.rows[ctr.idx], nil
     31 }
     32 
     33 func (ctr *copyFromRows) Err() error {
     34 	return nil
     35 }
     36 
     37 // CopyFromSlice returns a CopyFromSource interface over a dynamic func
     38 // making it usable by *Conn.CopyFrom.
     39 func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
     40 	return &copyFromSlice{next: next, idx: -1, len: length}
     41 }
     42 
     43 type copyFromSlice struct {
     44 	next func(int) ([]any, error)
     45 	idx  int
     46 	len  int
     47 	err  error
     48 }
     49 
     50 func (cts *copyFromSlice) Next() bool {
     51 	cts.idx++
     52 	return cts.idx < cts.len
     53 }
     54 
     55 func (cts *copyFromSlice) Values() ([]any, error) {
     56 	values, err := cts.next(cts.idx)
     57 	if err != nil {
     58 		cts.err = err
     59 	}
     60 	return values, err
     61 }
     62 
     63 func (cts *copyFromSlice) Err() error {
     64 	return cts.err
     65 }
     66 
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
//
// The copy loop calls Next and, for as long as it returns true, calls Values
// to fetch the row that Next made current. Err is consulted after each batch
// of rows has been encoded.
type CopyFromSource interface {
	// Next returns true if there is another row and makes the next row data
	// available to Values(). When there are no more rows available or an error
	// has occurred it returns false.
	Next() bool

	// Values returns the values for the current row. The number of values
	// must equal the number of columns being copied. A non-nil error aborts
	// the copy.
	Values() ([]any, error)

	// Err returns any error that has been encountered by the CopyFromSource. If
	// this is not nil *Conn.CopyFrom will abort the copy.
	Err() error
}
     81 
// copyFrom bundles the parameters and state of a single Conn.CopyFrom call.
type copyFrom struct {
	// conn is the connection the copy is executed on.
	conn          *Conn
	// tableName is the destination table; it is sanitized before being placed in SQL.
	tableName     Identifier
	// columnNames lists the destination columns; every row must supply one value per column.
	columnNames   []string
	// rowSrc supplies the rows to copy.
	rowSrc        CopyFromSource
	// readerErrChan is allocated by Conn.CopyFrom.
	// NOTE(review): it is not read or written anywhere else in this file — confirm whether it is still needed.
	readerErrChan chan error
	// mode is the query exec mode used to obtain the statement description.
	mode          QueryExecMode
}
     90 
     91 func (ct *copyFrom) run(ctx context.Context) (int64, error) {
     92 	if ct.conn.copyFromTracer != nil {
     93 		ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
     94 			TableName:   ct.tableName,
     95 			ColumnNames: ct.columnNames,
     96 		})
     97 	}
     98 
     99 	quotedTableName := ct.tableName.Sanitize()
    100 	cbuf := &bytes.Buffer{}
    101 	for i, cn := range ct.columnNames {
    102 		if i != 0 {
    103 			cbuf.WriteString(", ")
    104 		}
    105 		cbuf.WriteString(quoteIdentifier(cn))
    106 	}
    107 	quotedColumnNames := cbuf.String()
    108 
    109 	var sd *pgconn.StatementDescription
    110 	switch ct.mode {
    111 	case QueryExecModeExec, QueryExecModeSimpleProtocol:
    112 		// These modes don't support the binary format. Before the inclusion of the
    113 		// QueryExecModes, Conn.Prepare was called on every COPY operation to get
    114 		// the OIDs. These prepared statements were not cached.
    115 		//
    116 		// Since that's the same behavior provided by QueryExecModeDescribeExec,
    117 		// we'll default to that mode.
    118 		ct.mode = QueryExecModeDescribeExec
    119 		fallthrough
    120 	case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
    121 		var err error
    122 		sd, err = ct.conn.getStatementDescription(
    123 			ctx,
    124 			ct.mode,
    125 			fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
    126 		)
    127 		if err != nil {
    128 			return 0, fmt.Errorf("statement description failed: %w", err)
    129 		}
    130 	default:
    131 		return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
    132 	}
    133 
    134 	r, w := io.Pipe()
    135 	doneChan := make(chan struct{})
    136 
    137 	go func() {
    138 		defer close(doneChan)
    139 
    140 		// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
    141 		buf := ct.conn.wbuf
    142 
    143 		buf = append(buf, "PGCOPY\n\377\r\n\000"...)
    144 		buf = pgio.AppendInt32(buf, 0)
    145 		buf = pgio.AppendInt32(buf, 0)
    146 
    147 		moreRows := true
    148 		for moreRows {
    149 			var err error
    150 			moreRows, buf, err = ct.buildCopyBuf(buf, sd)
    151 			if err != nil {
    152 				w.CloseWithError(err)
    153 				return
    154 			}
    155 
    156 			if ct.rowSrc.Err() != nil {
    157 				w.CloseWithError(ct.rowSrc.Err())
    158 				return
    159 			}
    160 
    161 			if len(buf) > 0 {
    162 				_, err = w.Write(buf)
    163 				if err != nil {
    164 					w.Close()
    165 					return
    166 				}
    167 			}
    168 
    169 			buf = buf[:0]
    170 		}
    171 
    172 		w.Close()
    173 	}()
    174 
    175 	commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
    176 
    177 	r.Close()
    178 	<-doneChan
    179 
    180 	if ct.conn.copyFromTracer != nil {
    181 		ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
    182 			CommandTag: commandTag,
    183 			Err:        err,
    184 		})
    185 	}
    186 
    187 	return commandTag.RowsAffected(), err
    188 }
    189 
    190 func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
    191 	const sendBufSize = 65536 - 5 // The packet has a 5-byte header
    192 	lastBufLen := 0
    193 	largestRowLen := 0
    194 
    195 	for ct.rowSrc.Next() {
    196 		lastBufLen = len(buf)
    197 
    198 		values, err := ct.rowSrc.Values()
    199 		if err != nil {
    200 			return false, nil, err
    201 		}
    202 		if len(values) != len(ct.columnNames) {
    203 			return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
    204 		}
    205 
    206 		buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
    207 		for i, val := range values {
    208 			buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
    209 			if err != nil {
    210 				return false, nil, err
    211 			}
    212 		}
    213 
    214 		rowLen := len(buf) - lastBufLen
    215 		if rowLen > largestRowLen {
    216 			largestRowLen = rowLen
    217 		}
    218 
    219 		// Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of
    220 		// io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531
    221 		// 13, 65531, 13, 65531, 13.
    222 		if len(buf) > sendBufSize-largestRowLen {
    223 			return true, buf, nil
    224 		}
    225 	}
    226 
    227 	return false, buf, nil
    228 }
    229 
    230 // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
    231 // an error.
    232 //
    233 // CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
    234 // for the type of each column. Almost all types implemented by pgx support the binary format.
    235 //
    236 // Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with
    237 // Conn.LoadType and pgtype.Map.RegisterType.
    238 func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
    239 	ct := &copyFrom{
    240 		conn:          c,
    241 		tableName:     tableName,
    242 		columnNames:   columnNames,
    243 		rowSrc:        rowSrc,
    244 		readerErrChan: make(chan error),
    245 		mode:          c.config.DefaultQueryExecMode,
    246 	}
    247 
    248 	return ct.run(ctx)
    249 }