copy_from.go (6515B)
1 package pgx 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 9 "github.com/jackc/pgx/v5/internal/pgio" 10 "github.com/jackc/pgx/v5/pgconn" 11 ) 12 13 // CopyFromRows returns a CopyFromSource interface over the provided rows slice 14 // making it usable by *Conn.CopyFrom. 15 func CopyFromRows(rows [][]any) CopyFromSource { 16 return ©FromRows{rows: rows, idx: -1} 17 } 18 19 type copyFromRows struct { 20 rows [][]any 21 idx int 22 } 23 24 func (ctr *copyFromRows) Next() bool { 25 ctr.idx++ 26 return ctr.idx < len(ctr.rows) 27 } 28 29 func (ctr *copyFromRows) Values() ([]any, error) { 30 return ctr.rows[ctr.idx], nil 31 } 32 33 func (ctr *copyFromRows) Err() error { 34 return nil 35 } 36 37 // CopyFromSlice returns a CopyFromSource interface over a dynamic func 38 // making it usable by *Conn.CopyFrom. 39 func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource { 40 return ©FromSlice{next: next, idx: -1, len: length} 41 } 42 43 type copyFromSlice struct { 44 next func(int) ([]any, error) 45 idx int 46 len int 47 err error 48 } 49 50 func (cts *copyFromSlice) Next() bool { 51 cts.idx++ 52 return cts.idx < cts.len 53 } 54 55 func (cts *copyFromSlice) Values() ([]any, error) { 56 values, err := cts.next(cts.idx) 57 if err != nil { 58 cts.err = err 59 } 60 return values, err 61 } 62 63 func (cts *copyFromSlice) Err() error { 64 return cts.err 65 } 66 67 // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. 68 type CopyFromSource interface { 69 // Next returns true if there is another row and makes the next row data 70 // available to Values(). When there are no more rows available or an error 71 // has occurred it returns false. 72 Next() bool 73 74 // Values returns the values for the current row. 75 Values() ([]any, error) 76 77 // Err returns any error that has been encountered by the CopyFromSource. If 78 // this is not nil *Conn.CopyFrom will abort the copy. 79 Err() error 80 } 81 82 type copyFrom struct { 83 conn *Conn 84 tableName Identifier 85 columnNames []string 86 rowSrc CopyFromSource 87 readerErrChan chan error 88 mode QueryExecMode 89 } 90 91 func (ct *copyFrom) run(ctx context.Context) (int64, error) { 92 if ct.conn.copyFromTracer != nil { 93 ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{ 94 TableName: ct.tableName, 95 ColumnNames: ct.columnNames, 96 }) 97 } 98 99 quotedTableName := ct.tableName.Sanitize() 100 cbuf := &bytes.Buffer{} 101 for i, cn := range ct.columnNames { 102 if i != 0 { 103 cbuf.WriteString(", ") 104 } 105 cbuf.WriteString(quoteIdentifier(cn)) 106 } 107 quotedColumnNames := cbuf.String() 108 109 var sd *pgconn.StatementDescription 110 switch ct.mode { 111 case QueryExecModeExec, QueryExecModeSimpleProtocol: 112 // These modes don't support the binary format. Before the inclusion of the 113 // QueryExecModes, Conn.Prepare was called on every COPY operation to get 114 // the OIDs. These prepared statements were not cached. 115 // 116 // Since that's the same behavior provided by QueryExecModeDescribeExec, 117 // we'll default to that mode. 118 ct.mode = QueryExecModeDescribeExec 119 fallthrough 120 case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec: 121 var err error 122 sd, err = ct.conn.getStatementDescription( 123 ctx, 124 ct.mode, 125 fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName), 126 ) 127 if err != nil { 128 return 0, fmt.Errorf("statement description failed: %w", err) 129 } 130 default: 131 return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) 132 } 133 134 r, w := io.Pipe() 135 doneChan := make(chan struct{}) 136 137 go func() { 138 defer close(doneChan) 139 140 // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. 141 buf := ct.conn.wbuf 142 143 buf = append(buf, "PGCOPY\n\377\r\n\000"...) 144 buf = pgio.AppendInt32(buf, 0) 145 buf = pgio.AppendInt32(buf, 0) 146 147 moreRows := true 148 for moreRows { 149 var err error 150 moreRows, buf, err = ct.buildCopyBuf(buf, sd) 151 if err != nil { 152 w.CloseWithError(err) 153 return 154 } 155 156 if ct.rowSrc.Err() != nil { 157 w.CloseWithError(ct.rowSrc.Err()) 158 return 159 } 160 161 if len(buf) > 0 { 162 _, err = w.Write(buf) 163 if err != nil { 164 w.Close() 165 return 166 } 167 } 168 169 buf = buf[:0] 170 } 171 172 w.Close() 173 }() 174 175 commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) 176 177 r.Close() 178 <-doneChan 179 180 if ct.conn.copyFromTracer != nil { 181 ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{ 182 CommandTag: commandTag, 183 Err: err, 184 }) 185 } 186 187 return commandTag.RowsAffected(), err 188 } 189 190 func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { 191 const sendBufSize = 65536 - 5 // The packet has a 5-byte header 192 lastBufLen := 0 193 largestRowLen := 0 194 195 for ct.rowSrc.Next() { 196 lastBufLen = len(buf) 197 198 values, err := ct.rowSrc.Values() 199 if err != nil { 200 return false, nil, err 201 } 202 if len(values) != len(ct.columnNames) { 203 return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) 204 } 205 206 buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) 207 for i, val := range values { 208 buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) 209 if err != nil { 210 return false, nil, err 211 } 212 } 213 214 rowLen := len(buf) - lastBufLen 215 if rowLen > largestRowLen { 216 largestRowLen = rowLen 217 } 218 219 // Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of 220 // io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531 221 // 13, 65531, 13, 65531, 13. 222 if len(buf) > sendBufSize-largestRowLen { 223 return true, buf, nil 224 } 225 } 226 227 return false, buf, nil 228 } 229 230 // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and 231 // an error. 232 // 233 // CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered 234 // for the type of each column. Almost all types implemented by pgx support the binary format. 235 // 236 // Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with 237 // Conn.LoadType and pgtype.Map.RegisterType. 238 func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { 239 ct := ©From{ 240 conn: c, 241 tableName: tableName, 242 columnNames: columnNames, 243 rowSrc: rowSrc, 244 readerErrChan: make(chan error), 245 mode: c.config.DefaultQueryExecMode, 246 } 247 248 return ct.run(ctx) 249 }