pgconn.go (53872B)
package pgconn

import (
	"context"
	"crypto/md5"
	"crypto/tls"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgconn/internal/ctxwatch"
	"github.com/jackc/pgio"
	"github.com/jackc/pgproto3/v2"
)

// Connection states stored in PgConn.status. The declaration order is significant: IsClosed treats every state that
// sorts before connStatusIdle (uninitialized, connecting, closed) as not usable.
const (
	connStatusUninitialized = iota
	connStatusConnecting
	connStatusClosed
	connStatusIdle
	connStatusBusy
)

// wbufLen is the initial capacity of the write buffer each connection allocates once and reuses to encode outgoing
// messages.
const wbufLen = 1024

// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
// LISTEN/NOTIFY notification.
type Notice PgError

// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system.
type Notification struct {
	PID     uint32 // backend pid that sent the notification
	Channel string // channel from which notification was received
	Payload string // payload string carried by the notification
}

// DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// LookupFunc is a function that can be used to look up IP addresses for a host. Optionally an ip:port combination can
// be returned in order to override the connection string's port.
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)

// BuildFrontendFunc is a function that can be used to create the Frontend implementation for a connection. r and w
// are both the connection's underlying net.Conn.
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend

// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
// notification.
type NoticeHandler func(*PgConn, *Notice)

// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications
// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is
// aware of the origin of the notification, but it must not invoke any query method. Be aware that this is distinct
// from a notice event.
type NotificationHandler func(*PgConn, *Notification)

// Frontend used to receive messages from backend.
type Frontend interface {
	Receive() (pgproto3.BackendMessage, error)
}

// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
	conn              net.Conn          // the underlying TCP or unix domain socket connection
	pid               uint32            // backend pid
	secretKey         uint32            // key to use to send a cancel query message to the server
	parameterStatuses map[string]string // parameters that have been reported by the server
	txStatus          byte              // transaction status from the most recent ReadyForQuery message
	frontend          Frontend

	config *Config

	status byte // One of connStatus* constants

	// Result of a background read started by signalMessage. signalMessage locks the mutex and its goroutine unlocks
	// it once frontend.Receive returns; peekMessage locks it to wait for and collect the buffered result.
	bufferingReceive    bool
	bufferingReceiveMux sync.Mutex
	bufferingReceiveMsg pgproto3.BackendMessage
	bufferingReceiveErr error

	peekedMsg pgproto3.BackendMessage // message returned by peekMessage but not yet consumed by receiveMessage

	// Reusable / preallocated resources
	wbuf              []byte // write buffer
	resultReader      ResultReader
	multiResultReader MultiResultReader
	contextWatcher    *ctxwatch.ContextWatcher

	cleanupDone chan struct{} // closed once the underlying resources have been released (see CleanupDone)
}

// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt.
103 func Connect(ctx context.Context, connString string) (*PgConn, error) { 104 config, err := ParseConfig(connString) 105 if err != nil { 106 return nil, err 107 } 108 109 return ConnectConfig(ctx, config) 110 } 111 112 // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) 113 // and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be 114 // used to cancel a connect attempt. 115 func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { 116 config, err := ParseConfigWithOptions(connString, parseConfigOptions) 117 if err != nil { 118 return nil, err 119 } 120 121 return ConnectConfig(ctx, config) 122 } 123 124 // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with 125 // ParseConfig. ctx can be used to cancel a connect attempt. 126 // 127 // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An 128 // authentication error will terminate the chain of attempts (like libpq: 129 // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, 130 // if all attempts fail the last error is returned. 131 func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { 132 // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from 133 // zero values. 134 if !config.createdByParseConfig { 135 panic("config must be created by ParseConfig") 136 } 137 138 // Simplify usage by treating primary config and fallbacks the same. 139 fallbackConfigs := []*FallbackConfig{ 140 { 141 Host: config.Host, 142 Port: config.Port, 143 TLSConfig: config.TLSConfig, 144 }, 145 } 146 fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) 
147 ctx := octx 148 fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) 149 if err != nil { 150 return nil, &connectError{config: config, msg: "hostname resolving error", err: err} 151 } 152 153 if len(fallbackConfigs) == 0 { 154 return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} 155 } 156 157 foundBestServer := false 158 var fallbackConfig *FallbackConfig 159 for _, fc := range fallbackConfigs { 160 // ConnectTimeout restricts the whole connection process. 161 if config.ConnectTimeout != 0 { 162 var cancel context.CancelFunc 163 ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) 164 defer cancel() 165 } else { 166 ctx = octx 167 } 168 pgConn, err = connect(ctx, config, fc, false) 169 if err == nil { 170 foundBestServer = true 171 break 172 } else if pgerr, ok := err.(*PgError); ok { 173 err = &connectError{config: config, msg: "server error", err: pgerr} 174 const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password 175 const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings 176 const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist 177 const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege 178 if pgerr.Code == ERRCODE_INVALID_PASSWORD || 179 pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || 180 pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || 181 pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { 182 break 183 } 184 } else if cerr, ok := err.(*connectError); ok { 185 if _, ok := cerr.err.(*NotPreferredError); ok { 186 fallbackConfig = fc 187 } 188 } 189 } 190 191 if !foundBestServer && fallbackConfig != nil { 192 pgConn, err = connect(ctx, config, fallbackConfig, true) 193 if pgerr, ok := err.(*PgError); ok { 194 err = &connectError{config: config, msg: "server error", err: pgerr} 195 } 196 } 197 198 if err != nil { 199 return nil, err // no need to wrap in connectError because 
it will already be wrapped in all cases except PgError 200 } 201 202 if config.AfterConnect != nil { 203 err := config.AfterConnect(ctx, pgConn) 204 if err != nil { 205 pgConn.conn.Close() 206 return nil, &connectError{config: config, msg: "AfterConnect error", err: err} 207 } 208 } 209 210 return pgConn, nil 211 } 212 213 func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { 214 var configs []*FallbackConfig 215 216 for _, fb := range fallbacks { 217 // skip resolve for unix sockets 218 if isAbsolutePath(fb.Host) { 219 configs = append(configs, &FallbackConfig{ 220 Host: fb.Host, 221 Port: fb.Port, 222 TLSConfig: fb.TLSConfig, 223 }) 224 225 continue 226 } 227 228 ips, err := lookupFn(ctx, fb.Host) 229 if err != nil { 230 return nil, err 231 } 232 233 for _, ip := range ips { 234 splitIP, splitPort, err := net.SplitHostPort(ip) 235 if err == nil { 236 port, err := strconv.ParseUint(splitPort, 10, 16) 237 if err != nil { 238 return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) 239 } 240 configs = append(configs, &FallbackConfig{ 241 Host: splitIP, 242 Port: uint16(port), 243 TLSConfig: fb.TLSConfig, 244 }) 245 } else { 246 configs = append(configs, &FallbackConfig{ 247 Host: ip, 248 Port: fb.Port, 249 TLSConfig: fb.TLSConfig, 250 }) 251 } 252 } 253 } 254 255 return configs, nil 256 } 257 258 func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, 259 ignoreNotPreferredErr bool) (*PgConn, error) { 260 pgConn := new(PgConn) 261 pgConn.config = config 262 pgConn.wbuf = make([]byte, 0, wbufLen) 263 pgConn.cleanupDone = make(chan struct{}) 264 265 var err error 266 network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) 267 netConn, err := config.DialFunc(ctx, network, address) 268 if err != nil { 269 var netErr net.Error 270 if errors.As(err, &netErr) && netErr.Timeout() { 271 err = &errTimeout{err: err} 272 } 273 return nil, 
&connectError{config: config, msg: "dial error", err: err} 274 } 275 276 pgConn.conn = netConn 277 pgConn.contextWatcher = newContextWatcher(netConn) 278 pgConn.contextWatcher.Watch(ctx) 279 280 if fallbackConfig.TLSConfig != nil { 281 tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) 282 pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. 283 if err != nil { 284 netConn.Close() 285 return nil, &connectError{config: config, msg: "tls error", err: err} 286 } 287 288 pgConn.conn = tlsConn 289 pgConn.contextWatcher = newContextWatcher(tlsConn) 290 pgConn.contextWatcher.Watch(ctx) 291 } 292 293 defer pgConn.contextWatcher.Unwatch() 294 295 pgConn.parameterStatuses = make(map[string]string) 296 pgConn.status = connStatusConnecting 297 pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) 298 299 startupMsg := pgproto3.StartupMessage{ 300 ProtocolVersion: pgproto3.ProtocolVersionNumber, 301 Parameters: make(map[string]string), 302 } 303 304 // Copy default run-time params 305 for k, v := range config.RuntimeParams { 306 startupMsg.Parameters[k] = v 307 } 308 309 startupMsg.Parameters["user"] = config.User 310 if config.Database != "" { 311 startupMsg.Parameters["database"] = config.Database 312 } 313 314 if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { 315 pgConn.conn.Close() 316 return nil, &connectError{config: config, msg: "failed to write startup message", err: err} 317 } 318 319 for { 320 msg, err := pgConn.receiveMessage() 321 if err != nil { 322 pgConn.conn.Close() 323 if err, ok := err.(*PgError); ok { 324 return nil, err 325 } 326 return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} 327 } 328 329 switch msg := msg.(type) { 330 case *pgproto3.BackendKeyData: 331 pgConn.pid = msg.ProcessID 332 pgConn.secretKey = msg.SecretKey 333 334 case *pgproto3.AuthenticationOk: 335 case *pgproto3.AuthenticationCleartextPassword: 336 err = 
pgConn.txPasswordMessage(pgConn.config.Password) 337 if err != nil { 338 pgConn.conn.Close() 339 return nil, &connectError{config: config, msg: "failed to write password message", err: err} 340 } 341 case *pgproto3.AuthenticationMD5Password: 342 digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) 343 err = pgConn.txPasswordMessage(digestedPassword) 344 if err != nil { 345 pgConn.conn.Close() 346 return nil, &connectError{config: config, msg: "failed to write password message", err: err} 347 } 348 case *pgproto3.AuthenticationSASL: 349 err = pgConn.scramAuth(msg.AuthMechanisms) 350 if err != nil { 351 pgConn.conn.Close() 352 return nil, &connectError{config: config, msg: "failed SASL auth", err: err} 353 } 354 case *pgproto3.AuthenticationGSS: 355 err = pgConn.gssAuth() 356 if err != nil { 357 pgConn.conn.Close() 358 return nil, &connectError{config: config, msg: "failed GSS auth", err: err} 359 } 360 case *pgproto3.ReadyForQuery: 361 pgConn.status = connStatusIdle 362 if config.ValidateConnect != nil { 363 // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid 364 // the watch already in progress panic. This is that last thing done by this method so there is no need to 365 // restart the watch after ValidateConnect returns. 366 // 367 // See https://github.com/jackc/pgconn/issues/40. 
368 pgConn.contextWatcher.Unwatch() 369 370 err := config.ValidateConnect(ctx, pgConn) 371 if err != nil { 372 if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok { 373 return pgConn, nil 374 } 375 pgConn.conn.Close() 376 return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} 377 } 378 } 379 return pgConn, nil 380 case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: 381 // handled by ReceiveMessage 382 case *pgproto3.ErrorResponse: 383 pgConn.conn.Close() 384 return nil, ErrorResponseToPgError(msg) 385 default: 386 pgConn.conn.Close() 387 return nil, &connectError{config: config, msg: "received unexpected message", err: err} 388 } 389 } 390 } 391 392 func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { 393 return ctxwatch.NewContextWatcher( 394 func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, 395 func() { conn.SetDeadline(time.Time{}) }, 396 ) 397 } 398 399 func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { 400 err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) 401 if err != nil { 402 return nil, err 403 } 404 405 response := make([]byte, 1) 406 if _, err = io.ReadFull(conn, response); err != nil { 407 return nil, err 408 } 409 410 if response[0] != 'S' { 411 return nil, errors.New("server refused TLS connection") 412 } 413 414 return tls.Client(conn, tlsConfig), nil 415 } 416 417 func (pgConn *PgConn) txPasswordMessage(password string) (err error) { 418 msg := &pgproto3.PasswordMessage{Password: password} 419 _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) 420 return err 421 } 422 423 func hexMD5(s string) string { 424 hash := md5.New() 425 io.WriteString(hash, s) 426 return hex.EncodeToString(hash.Sum(nil)) 427 } 428 429 func (pgConn *PgConn) signalMessage() chan struct{} { 430 if pgConn.bufferingReceive { 431 panic("BUG: signalMessage when already in progress") 432 } 433 434 pgConn.bufferingReceive = true 435 pgConn.bufferingReceiveMux.Lock() 
436 437 ch := make(chan struct{}) 438 go func() { 439 pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() 440 pgConn.bufferingReceiveMux.Unlock() 441 close(ch) 442 }() 443 444 return ch 445 } 446 447 // SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as 448 // error to call SendBytes while reading the result of a query. 449 // 450 // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. 451 // See https://www.postgresql.org/docs/current/protocol.html. 452 func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { 453 if err := pgConn.lock(); err != nil { 454 return err 455 } 456 defer pgConn.unlock() 457 458 if ctx != context.Background() { 459 select { 460 case <-ctx.Done(): 461 return newContextAlreadyDoneError(ctx) 462 default: 463 } 464 pgConn.contextWatcher.Watch(ctx) 465 defer pgConn.contextWatcher.Unwatch() 466 } 467 468 n, err := pgConn.conn.Write(buf) 469 if err != nil { 470 pgConn.asyncClose() 471 return &writeError{err: err, safeToRetry: n == 0} 472 } 473 474 return nil 475 } 476 477 // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the 478 // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages 479 // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger 480 // the OnNotification callback. 481 // 482 // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. 483 // See https://www.postgresql.org/docs/current/protocol.html. 
484 func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { 485 if err := pgConn.lock(); err != nil { 486 return nil, err 487 } 488 defer pgConn.unlock() 489 490 if ctx != context.Background() { 491 select { 492 case <-ctx.Done(): 493 return nil, newContextAlreadyDoneError(ctx) 494 default: 495 } 496 pgConn.contextWatcher.Watch(ctx) 497 defer pgConn.contextWatcher.Unwatch() 498 } 499 500 msg, err := pgConn.receiveMessage() 501 if err != nil { 502 err = &pgconnError{ 503 msg: "receive message failed", 504 err: preferContextOverNetTimeoutError(ctx, err), 505 safeToRetry: true} 506 } 507 return msg, err 508 } 509 510 // peekMessage peeks at the next message without setting up context cancellation. 511 func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { 512 if pgConn.peekedMsg != nil { 513 return pgConn.peekedMsg, nil 514 } 515 516 var msg pgproto3.BackendMessage 517 var err error 518 if pgConn.bufferingReceive { 519 pgConn.bufferingReceiveMux.Lock() 520 msg = pgConn.bufferingReceiveMsg 521 err = pgConn.bufferingReceiveErr 522 pgConn.bufferingReceiveMux.Unlock() 523 pgConn.bufferingReceive = false 524 525 // If a timeout error happened in the background try the read again. 
526 var netErr net.Error 527 if errors.As(err, &netErr) && netErr.Timeout() { 528 msg, err = pgConn.frontend.Receive() 529 } 530 } else { 531 msg, err = pgConn.frontend.Receive() 532 } 533 534 if err != nil { 535 // Close on anything other than timeout error - everything else is fatal 536 var netErr net.Error 537 isNetErr := errors.As(err, &netErr) 538 if !(isNetErr && netErr.Timeout()) { 539 pgConn.asyncClose() 540 } 541 542 return nil, err 543 } 544 545 pgConn.peekedMsg = msg 546 return msg, nil 547 } 548 549 // receiveMessage receives a message without setting up context cancellation 550 func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { 551 msg, err := pgConn.peekMessage() 552 if err != nil { 553 // Close on anything other than timeout error - everything else is fatal 554 var netErr net.Error 555 isNetErr := errors.As(err, &netErr) 556 if !(isNetErr && netErr.Timeout()) { 557 pgConn.asyncClose() 558 } 559 560 return nil, err 561 } 562 pgConn.peekedMsg = nil 563 564 switch msg := msg.(type) { 565 case *pgproto3.ReadyForQuery: 566 pgConn.txStatus = msg.TxStatus 567 case *pgproto3.ParameterStatus: 568 pgConn.parameterStatuses[msg.Name] = msg.Value 569 case *pgproto3.ErrorResponse: 570 if msg.Severity == "FATAL" { 571 pgConn.status = connStatusClosed 572 pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. 573 close(pgConn.cleanupDone) 574 return nil, ErrorResponseToPgError(msg) 575 } 576 case *pgproto3.NoticeResponse: 577 if pgConn.config.OnNotice != nil { 578 pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) 579 } 580 case *pgproto3.NotificationResponse: 581 if pgConn.config.OnNotification != nil { 582 pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) 583 } 584 } 585 586 return msg, nil 587 } 588 589 // Conn returns the underlying net.Conn. 
590 func (pgConn *PgConn) Conn() net.Conn { 591 return pgConn.conn 592 } 593 594 // PID returns the backend PID. 595 func (pgConn *PgConn) PID() uint32 { 596 return pgConn.pid 597 } 598 599 // TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. 600 // 601 // Possible return values: 602 // 'I' - idle / not in transaction 603 // 'T' - in a transaction 604 // 'E' - in a failed transaction 605 // 606 // See https://www.postgresql.org/docs/current/protocol-message-formats.html. 607 func (pgConn *PgConn) TxStatus() byte { 608 return pgConn.txStatus 609 } 610 611 // SecretKey returns the backend secret key used to send a cancel query message to the server. 612 func (pgConn *PgConn) SecretKey() uint32 { 613 return pgConn.secretKey 614 } 615 616 // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by 617 // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The 618 // underlying net.Conn.Close() will always be called regardless of any other errors. 619 func (pgConn *PgConn) Close(ctx context.Context) error { 620 if pgConn.status == connStatusClosed { 621 return nil 622 } 623 pgConn.status = connStatusClosed 624 625 defer close(pgConn.cleanupDone) 626 defer pgConn.conn.Close() 627 628 if ctx != context.Background() { 629 // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when 630 // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any 631 // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. 
632 // 633 // See https://github.com/jackc/pgconn/issues/29 634 pgConn.contextWatcher.Unwatch() 635 636 pgConn.contextWatcher.Watch(ctx) 637 defer pgConn.contextWatcher.Unwatch() 638 } 639 640 // Ignore any errors sending Terminate message and waiting for server to close connection. 641 // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully 642 // ignores errors. 643 // 644 // See https://github.com/jackc/pgx/issues/637 645 pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) 646 647 return pgConn.conn.Close() 648 } 649 650 // asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying 651 // connection. 652 func (pgConn *PgConn) asyncClose() { 653 if pgConn.status == connStatusClosed { 654 return 655 } 656 pgConn.status = connStatusClosed 657 658 go func() { 659 defer close(pgConn.cleanupDone) 660 defer pgConn.conn.Close() 661 662 deadline := time.Now().Add(time.Second * 15) 663 664 ctx, cancel := context.WithDeadline(context.Background(), deadline) 665 defer cancel() 666 667 pgConn.CancelRequest(ctx) 668 669 pgConn.conn.SetDeadline(deadline) 670 671 pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) 672 }() 673 } 674 675 // CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed 676 // connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing 677 // yet. This is because certain errors such as a context cancellation require that the interrupted function call return 678 // immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are 679 // closed asynchronously. 680 // 681 // This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while 682 // an old connection is still being cleaned up and thereby exceeding the maximum pool size. 
683 func (pgConn *PgConn) CleanupDone() chan (struct{}) { 684 return pgConn.cleanupDone 685 } 686 687 // IsClosed reports if the connection has been closed. 688 // 689 // CleanupDone() can be used to determine if all cleanup has been completed. 690 func (pgConn *PgConn) IsClosed() bool { 691 return pgConn.status < connStatusIdle 692 } 693 694 // IsBusy reports if the connection is busy. 695 func (pgConn *PgConn) IsBusy() bool { 696 return pgConn.status == connStatusBusy 697 } 698 699 // lock locks the connection. 700 func (pgConn *PgConn) lock() error { 701 switch pgConn.status { 702 case connStatusBusy: 703 return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. 704 case connStatusClosed: 705 return &connLockError{status: "conn closed"} 706 case connStatusUninitialized: 707 return &connLockError{status: "conn uninitialized"} 708 } 709 pgConn.status = connStatusBusy 710 return nil 711 } 712 713 func (pgConn *PgConn) unlock() { 714 switch pgConn.status { 715 case connStatusBusy: 716 pgConn.status = connStatusIdle 717 case connStatusClosed: 718 default: 719 panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. 720 } 721 } 722 723 // ParameterStatus returns the value of a parameter reported by the server (e.g. 724 // server_version). Returns an empty string for unknown parameters. 725 func (pgConn *PgConn) ParameterStatus(key string) string { 726 return pgConn.parameterStatuses[key] 727 } 728 729 // CommandTag is the result of an Exec function 730 type CommandTag []byte 731 732 // RowsAffected returns the number of rows affected. If the CommandTag was not 733 // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. 
734 func (ct CommandTag) RowsAffected() int64 { 735 // Find last non-digit 736 idx := -1 737 for i := len(ct) - 1; i >= 0; i-- { 738 if ct[i] >= '0' && ct[i] <= '9' { 739 idx = i 740 } else { 741 break 742 } 743 } 744 745 if idx == -1 { 746 return 0 747 } 748 749 var n int64 750 for _, b := range ct[idx:] { 751 n = n*10 + int64(b-'0') 752 } 753 754 return n 755 } 756 757 func (ct CommandTag) String() string { 758 return string(ct) 759 } 760 761 // Insert is true if the command tag starts with "INSERT". 762 func (ct CommandTag) Insert() bool { 763 return len(ct) >= 6 && 764 ct[0] == 'I' && 765 ct[1] == 'N' && 766 ct[2] == 'S' && 767 ct[3] == 'E' && 768 ct[4] == 'R' && 769 ct[5] == 'T' 770 } 771 772 // Update is true if the command tag starts with "UPDATE". 773 func (ct CommandTag) Update() bool { 774 return len(ct) >= 6 && 775 ct[0] == 'U' && 776 ct[1] == 'P' && 777 ct[2] == 'D' && 778 ct[3] == 'A' && 779 ct[4] == 'T' && 780 ct[5] == 'E' 781 } 782 783 // Delete is true if the command tag starts with "DELETE". 784 func (ct CommandTag) Delete() bool { 785 return len(ct) >= 6 && 786 ct[0] == 'D' && 787 ct[1] == 'E' && 788 ct[2] == 'L' && 789 ct[3] == 'E' && 790 ct[4] == 'T' && 791 ct[5] == 'E' 792 } 793 794 // Select is true if the command tag starts with "SELECT". 795 func (ct CommandTag) Select() bool { 796 return len(ct) >= 6 && 797 ct[0] == 'S' && 798 ct[1] == 'E' && 799 ct[2] == 'L' && 800 ct[3] == 'E' && 801 ct[4] == 'C' && 802 ct[5] == 'T' 803 } 804 805 type StatementDescription struct { 806 Name string 807 SQL string 808 ParamOIDs []uint32 809 Fields []pgproto3.FieldDescription 810 } 811 812 // Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This 813 // allows Prepare to also to describe statements without creating a server-side prepared statement. 
814 func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { 815 if err := pgConn.lock(); err != nil { 816 return nil, err 817 } 818 defer pgConn.unlock() 819 820 if ctx != context.Background() { 821 select { 822 case <-ctx.Done(): 823 return nil, newContextAlreadyDoneError(ctx) 824 default: 825 } 826 pgConn.contextWatcher.Watch(ctx) 827 defer pgConn.contextWatcher.Unwatch() 828 } 829 830 buf := pgConn.wbuf 831 buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) 832 buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) 833 buf = (&pgproto3.Sync{}).Encode(buf) 834 835 n, err := pgConn.conn.Write(buf) 836 if err != nil { 837 pgConn.asyncClose() 838 return nil, &writeError{err: err, safeToRetry: n == 0} 839 } 840 841 psd := &StatementDescription{Name: name, SQL: sql} 842 843 var parseErr error 844 845 readloop: 846 for { 847 msg, err := pgConn.receiveMessage() 848 if err != nil { 849 pgConn.asyncClose() 850 return nil, preferContextOverNetTimeoutError(ctx, err) 851 } 852 853 switch msg := msg.(type) { 854 case *pgproto3.ParameterDescription: 855 psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) 856 copy(psd.ParamOIDs, msg.ParameterOIDs) 857 case *pgproto3.RowDescription: 858 psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) 859 copy(psd.Fields, msg.Fields) 860 case *pgproto3.ErrorResponse: 861 parseErr = ErrorResponseToPgError(msg) 862 case *pgproto3.ReadyForQuery: 863 break readloop 864 } 865 } 866 867 if parseErr != nil { 868 return nil, parseErr 869 } 870 return psd, nil 871 } 872 873 // ErrorResponseToPgError converts a wire protocol error message to a *PgError. 
874 func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { 875 return &PgError{ 876 Severity: msg.Severity, 877 Code: string(msg.Code), 878 Message: string(msg.Message), 879 Detail: string(msg.Detail), 880 Hint: msg.Hint, 881 Position: msg.Position, 882 InternalPosition: msg.InternalPosition, 883 InternalQuery: string(msg.InternalQuery), 884 Where: string(msg.Where), 885 SchemaName: string(msg.SchemaName), 886 TableName: string(msg.TableName), 887 ColumnName: string(msg.ColumnName), 888 DataTypeName: string(msg.DataTypeName), 889 ConstraintName: msg.ConstraintName, 890 File: string(msg.File), 891 Line: msg.Line, 892 Routine: string(msg.Routine), 893 } 894 } 895 896 func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { 897 pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) 898 return (*Notice)(pgerr) 899 } 900 901 // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel 902 // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there 903 // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 904 func (pgConn *PgConn) CancelRequest(ctx context.Context) error { 905 // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing 906 // the connection config. This is important in high availability configurations where fallback connections may be 907 // specified or DNS may be used to load balance. 
908 serverAddr := pgConn.conn.RemoteAddr() 909 cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) 910 if err != nil { 911 return err 912 } 913 defer cancelConn.Close() 914 915 if ctx != context.Background() { 916 contextWatcher := ctxwatch.NewContextWatcher( 917 func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, 918 func() { cancelConn.SetDeadline(time.Time{}) }, 919 ) 920 contextWatcher.Watch(ctx) 921 defer contextWatcher.Unwatch() 922 } 923 924 buf := make([]byte, 16) 925 binary.BigEndian.PutUint32(buf[0:4], 16) 926 binary.BigEndian.PutUint32(buf[4:8], 80877102) 927 binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) 928 binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) 929 _, err = cancelConn.Write(buf) 930 if err != nil { 931 return err 932 } 933 934 _, err = cancelConn.Read(buf) 935 if err != io.EOF { 936 return err 937 } 938 939 return nil 940 } 941 942 // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not 943 // received. 944 func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { 945 if err := pgConn.lock(); err != nil { 946 return err 947 } 948 defer pgConn.unlock() 949 950 if ctx != context.Background() { 951 select { 952 case <-ctx.Done(): 953 return newContextAlreadyDoneError(ctx) 954 default: 955 } 956 957 pgConn.contextWatcher.Watch(ctx) 958 defer pgConn.contextWatcher.Unwatch() 959 } 960 961 for { 962 msg, err := pgConn.receiveMessage() 963 if err != nil { 964 return preferContextOverNetTimeoutError(ctx, err) 965 } 966 967 switch msg.(type) { 968 case *pgproto3.NotificationResponse: 969 return nil 970 } 971 } 972 } 973 974 // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is 975 // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control 976 // statements. 
977 // 978 // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. 979 func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { 980 if err := pgConn.lock(); err != nil { 981 return &MultiResultReader{ 982 closed: true, 983 err: err, 984 } 985 } 986 987 pgConn.multiResultReader = MultiResultReader{ 988 pgConn: pgConn, 989 ctx: ctx, 990 } 991 multiResult := &pgConn.multiResultReader 992 if ctx != context.Background() { 993 select { 994 case <-ctx.Done(): 995 multiResult.closed = true 996 multiResult.err = newContextAlreadyDoneError(ctx) 997 pgConn.unlock() 998 return multiResult 999 default: 1000 } 1001 pgConn.contextWatcher.Watch(ctx) 1002 } 1003 1004 buf := pgConn.wbuf 1005 buf = (&pgproto3.Query{String: sql}).Encode(buf) 1006 1007 n, err := pgConn.conn.Write(buf) 1008 if err != nil { 1009 pgConn.asyncClose() 1010 pgConn.contextWatcher.Unwatch() 1011 multiResult.closed = true 1012 multiResult.err = &writeError{err: err, safeToRetry: n == 0} 1013 pgConn.unlock() 1014 return multiResult 1015 } 1016 1017 return multiResult 1018 } 1019 1020 // ReceiveResults reads the result that might be returned by Postgres after a SendBytes 1021 // (e.a. after sending a CopyDone in a copy-both situation). 1022 // 1023 // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. 1024 // See https://www.postgresql.org/docs/current/protocol.html. 
1025 func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { 1026 if err := pgConn.lock(); err != nil { 1027 return &MultiResultReader{ 1028 closed: true, 1029 err: err, 1030 } 1031 } 1032 1033 pgConn.multiResultReader = MultiResultReader{ 1034 pgConn: pgConn, 1035 ctx: ctx, 1036 } 1037 multiResult := &pgConn.multiResultReader 1038 if ctx != context.Background() { 1039 select { 1040 case <-ctx.Done(): 1041 multiResult.closed = true 1042 multiResult.err = newContextAlreadyDoneError(ctx) 1043 pgConn.unlock() 1044 return multiResult 1045 default: 1046 } 1047 pgConn.contextWatcher.Watch(ctx) 1048 } 1049 1050 return multiResult 1051 } 1052 1053 // ExecParams executes a command via the PostgreSQL extended query protocol. 1054 // 1055 // sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, 1056 // etc. 1057 // 1058 // paramValues are the parameter values. It must be encoded in the format given by paramFormats. 1059 // 1060 // paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for 1061 // all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. 1062 // ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). 1063 // 1064 // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or 1065 // binary format. If paramFormats is nil all params are text format. ExecParams will panic if 1066 // len(paramFormats) is not 0, 1, or len(paramValues). 1067 // 1068 // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or 1069 // binary format. If resultFormats is nil all results will be in text format. 1070 // 1071 // ResultReader must be closed before PgConn can be used again. 
1072 func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { 1073 result := pgConn.execExtendedPrefix(ctx, paramValues) 1074 if result.closed { 1075 return result 1076 } 1077 1078 buf := pgConn.wbuf 1079 buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) 1080 buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) 1081 1082 pgConn.execExtendedSuffix(buf, result) 1083 1084 return result 1085 } 1086 1087 // ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. 1088 // 1089 // paramValues are the parameter values. It must be encoded in the format given by paramFormats. 1090 // 1091 // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or 1092 // binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if 1093 // len(paramFormats) is not 0, 1, or len(paramValues). 1094 // 1095 // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or 1096 // binary format. If resultFormats is nil all results will be in text format. 1097 // 1098 // ResultReader must be closed before PgConn can be used again. 
1099 func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { 1100 result := pgConn.execExtendedPrefix(ctx, paramValues) 1101 if result.closed { 1102 return result 1103 } 1104 1105 buf := pgConn.wbuf 1106 buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) 1107 1108 pgConn.execExtendedSuffix(buf, result) 1109 1110 return result 1111 } 1112 1113 func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { 1114 pgConn.resultReader = ResultReader{ 1115 pgConn: pgConn, 1116 ctx: ctx, 1117 } 1118 result := &pgConn.resultReader 1119 1120 if err := pgConn.lock(); err != nil { 1121 result.concludeCommand(nil, err) 1122 result.closed = true 1123 return result 1124 } 1125 1126 if len(paramValues) > math.MaxUint16 { 1127 result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) 1128 result.closed = true 1129 pgConn.unlock() 1130 return result 1131 } 1132 1133 if ctx != context.Background() { 1134 select { 1135 case <-ctx.Done(): 1136 result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) 1137 result.closed = true 1138 pgConn.unlock() 1139 return result 1140 default: 1141 } 1142 pgConn.contextWatcher.Watch(ctx) 1143 } 1144 1145 return result 1146 } 1147 1148 func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { 1149 buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) 1150 buf = (&pgproto3.Execute{}).Encode(buf) 1151 buf = (&pgproto3.Sync{}).Encode(buf) 1152 1153 n, err := pgConn.conn.Write(buf) 1154 if err != nil { 1155 pgConn.asyncClose() 1156 result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) 1157 pgConn.contextWatcher.Unwatch() 1158 result.closed = true 1159 pgConn.unlock() 1160 return 1161 } 1162 1163 result.readUntilRowDescription() 
1164 } 1165 1166 // CopyTo executes the copy command sql and copies the results to w. 1167 func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { 1168 if err := pgConn.lock(); err != nil { 1169 return nil, err 1170 } 1171 1172 if ctx != context.Background() { 1173 select { 1174 case <-ctx.Done(): 1175 pgConn.unlock() 1176 return nil, newContextAlreadyDoneError(ctx) 1177 default: 1178 } 1179 pgConn.contextWatcher.Watch(ctx) 1180 defer pgConn.contextWatcher.Unwatch() 1181 } 1182 1183 // Send copy to command 1184 buf := pgConn.wbuf 1185 buf = (&pgproto3.Query{String: sql}).Encode(buf) 1186 1187 n, err := pgConn.conn.Write(buf) 1188 if err != nil { 1189 pgConn.asyncClose() 1190 pgConn.unlock() 1191 return nil, &writeError{err: err, safeToRetry: n == 0} 1192 } 1193 1194 // Read results 1195 var commandTag CommandTag 1196 var pgErr error 1197 for { 1198 msg, err := pgConn.receiveMessage() 1199 if err != nil { 1200 pgConn.asyncClose() 1201 return nil, preferContextOverNetTimeoutError(ctx, err) 1202 } 1203 1204 switch msg := msg.(type) { 1205 case *pgproto3.CopyDone: 1206 case *pgproto3.CopyData: 1207 _, err := w.Write(msg.Data) 1208 if err != nil { 1209 pgConn.asyncClose() 1210 return nil, err 1211 } 1212 case *pgproto3.ReadyForQuery: 1213 pgConn.unlock() 1214 return commandTag, pgErr 1215 case *pgproto3.CommandComplete: 1216 commandTag = CommandTag(msg.CommandTag) 1217 case *pgproto3.ErrorResponse: 1218 pgErr = ErrorResponseToPgError(msg) 1219 } 1220 } 1221 } 1222 1223 // CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. 1224 // 1225 // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r 1226 // could still block. 
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Send the COPY ... FROM command.
	buf := pgConn.wbuf
	buf = (&pgproto3.Query{String: sql}).Encode(buf)

	n, err := pgConn.conn.Write(buf)
	if err != nil {
		pgConn.asyncClose()
		return nil, &writeError{err: err, safeToRetry: n == 0}
	}

	// Send copy data.
	//
	// A separate goroutine reads r and writes CopyData messages. Meanwhile this goroutine watches for backend messages
	// so that an ErrorResponse sent mid-copy can stop the transfer.
	// - abortCopyChan tells the copy goroutine to stop.
	// - copyErrChan carries the copy goroutine's terminal read or write error; io.EOF means r was fully consumed.
	// - signalMessageChan presumably fires when a backend message is ready to be received. pgConn.signalMessage is
	//   not shown here; TODO confirm.
	abortCopyChan := make(chan struct{})
	copyErrChan := make(chan error, 1)
	signalMessageChan := pgConn.signalMessage()
	var wg sync.WaitGroup
	wg.Add(1)

	go func() {
		defer wg.Done()
		// Each message is 'd' + int32 length + payload. The payload is read directly into buf[5:] so the header can
		// be filled in place without copying.
		buf := make([]byte, 0, 65536)
		buf = append(buf, 'd')
		sp := len(buf)

		for {
			n, readErr := r.Read(buf[5:cap(buf)])
			if n > 0 {
				buf = buf[0 : n+5]
				// The length field counts itself (4 bytes) plus the payload, but not the 'd' type byte.
				pgio.SetInt32(buf[sp:], int32(n+4))

				_, writeErr := pgConn.conn.Write(buf)
				if writeErr != nil {
					// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
					pgConn.conn.Close()

					copyErrChan <- writeErr
					return
				}
			}
			if readErr != nil {
				copyErrChan <- readErr
				return
			}

			select {
			case <-abortCopyChan:
				return
			default:
			}
		}
	}()

	// Wait until the copy goroutine reports a terminal error (io.EOF included) or the server sends an ErrorResponse.
	// Any other backend message is skipped, and the signal channel is re-armed.
	var pgErr error
	var copyErr error
	for copyErr == nil && pgErr == nil {
		select {
		case copyErr = <-copyErrChan:
		case <-signalMessageChan:
			msg, err := pgConn.receiveMessage()
			if err != nil {
				// NOTE(review): this path returns without closing abortCopyChan or waiting on wg. The copy goroutine may
				// therefore still be reading r and writing to pgConn.conn after CopyFrom returns, concurrently with
				// whatever asyncClose does with the connection. Confirm this is safe.
				pgConn.asyncClose()
				return nil, preferContextOverNetTimeoutError(ctx, err)
			}

			switch msg := msg.(type) {
			case *pgproto3.ErrorResponse:
				pgErr = ErrorResponseToPgError(msg)
			default:
				signalMessageChan = pgConn.signalMessage()
			}
		}
	}
	close(abortCopyChan)
	// Make sure io goroutine finishes before writing.
	wg.Wait()

	// Finish the copy:
	// - CopyDone is sent when r hit io.EOF or when the server already reported an error.
	// - Any other copy error (a read error from r, or a write error) is reported to the server with CopyFail.
	buf = buf[:0]
	if copyErr == io.EOF || pgErr != nil {
		copyDone := &pgproto3.CopyDone{}
		buf = copyDone.Encode(buf)
	} else {
		copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
		buf = copyFail.Encode(buf)
	}
	// If the copy goroutine failed on a write, pgConn.conn is already closed. This write then fails, and its error
	// (not copyErr) is what gets returned.
	_, err = pgConn.conn.Write(buf)
	if err != nil {
		pgConn.asyncClose()
		return nil, err
	}

	// Read results until ReadyForQuery, keeping the last command tag and the last server error.
	var commandTag CommandTag
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return nil, preferContextOverNetTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			return commandTag, pgErr
		case *pgproto3.CommandComplete:
			commandTag = CommandTag(msg.CommandTag)
		case *pgproto3.ErrorResponse:
			pgErr = ErrorResponseToPgError(msg)
		}
	}
}

// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
	pgConn *PgConn         // connection the results are received from
	ctx    context.Context // ctx of the originating call; used to prefer context errors over net timeouts

	rr *ResultReader // reader for the current result, as returned by ResultReader

	closed bool  // true once the reader is finished (ReadyForQuery, a receive error, or a failed start)
	err    error // error from a failed receive or a server ErrorResponse
}

// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods.
func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
	var results []*Result

	for mrr.NextResult() {
		results = append(results, mrr.ResultReader().Read())
	}
	err := mrr.Close()

	return results, err
}

// receiveMessage reads the next backend message from pgConn and updates the reader's state.
//
// - On a receive error it stops the context watcher, records the error (preferring a context error over a net
//   timeout), marks the reader closed, closes the connection asynchronously, and returns the recorded error.
// - On ReadyForQuery it stops the context watcher, marks the reader closed, and unlocks the connection.
// - On ErrorResponse it records the error.
//
// Every successfully received message is returned to the caller.
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
	msg, err := mrr.pgConn.receiveMessage()

	if err != nil {
		mrr.pgConn.contextWatcher.Unwatch()
		mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
		mrr.closed = true
		mrr.pgConn.asyncClose()
		return nil, mrr.err
	}

	switch msg := msg.(type) {
	case *pgproto3.ReadyForQuery:
		mrr.pgConn.contextWatcher.Unwatch()
		mrr.closed = true
		mrr.pgConn.unlock()
	case *pgproto3.ErrorResponse:
		mrr.err = ErrorResponseToPgError(msg)
	}

	return msg, nil
}

// NextResult advances the MultiResultReader to the next result and returns true if a result is available.
1399 func (mrr *MultiResultReader) NextResult() bool { 1400 for !mrr.closed && mrr.err == nil { 1401 msg, err := mrr.receiveMessage() 1402 if err != nil { 1403 return false 1404 } 1405 1406 switch msg := msg.(type) { 1407 case *pgproto3.RowDescription: 1408 mrr.pgConn.resultReader = ResultReader{ 1409 pgConn: mrr.pgConn, 1410 multiResultReader: mrr, 1411 ctx: mrr.ctx, 1412 fieldDescriptions: msg.Fields, 1413 } 1414 mrr.rr = &mrr.pgConn.resultReader 1415 return true 1416 case *pgproto3.CommandComplete: 1417 mrr.pgConn.resultReader = ResultReader{ 1418 commandTag: CommandTag(msg.CommandTag), 1419 commandConcluded: true, 1420 closed: true, 1421 } 1422 mrr.rr = &mrr.pgConn.resultReader 1423 return true 1424 case *pgproto3.EmptyQueryResponse: 1425 return false 1426 } 1427 } 1428 1429 return false 1430 } 1431 1432 // ResultReader returns the current ResultReader. 1433 func (mrr *MultiResultReader) ResultReader() *ResultReader { 1434 return mrr.rr 1435 } 1436 1437 // Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. 1438 func (mrr *MultiResultReader) Close() error { 1439 for !mrr.closed { 1440 _, err := mrr.receiveMessage() 1441 if err != nil { 1442 return mrr.err 1443 } 1444 } 1445 1446 return mrr.err 1447 } 1448 1449 // ResultReader is a reader for the result of a single query. 1450 type ResultReader struct { 1451 pgConn *PgConn 1452 multiResultReader *MultiResultReader 1453 ctx context.Context 1454 1455 fieldDescriptions []pgproto3.FieldDescription 1456 rowValues [][]byte 1457 commandTag CommandTag 1458 commandConcluded bool 1459 closed bool 1460 err error 1461 } 1462 1463 // Result is the saved query response that is returned by calling Read on a ResultReader. 1464 type Result struct { 1465 FieldDescriptions []pgproto3.FieldDescription 1466 Rows [][][]byte 1467 CommandTag CommandTag 1468 Err error 1469 } 1470 1471 // Read saves the query response to a Result. 
1472 func (rr *ResultReader) Read() *Result { 1473 br := &Result{} 1474 1475 for rr.NextRow() { 1476 if br.FieldDescriptions == nil { 1477 br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) 1478 copy(br.FieldDescriptions, rr.FieldDescriptions()) 1479 } 1480 1481 row := make([][]byte, len(rr.Values())) 1482 copy(row, rr.Values()) 1483 br.Rows = append(br.Rows, row) 1484 } 1485 1486 br.CommandTag, br.Err = rr.Close() 1487 1488 return br 1489 } 1490 1491 // NextRow advances the ResultReader to the next row and returns true if a row is available. 1492 func (rr *ResultReader) NextRow() bool { 1493 for !rr.commandConcluded { 1494 msg, err := rr.receiveMessage() 1495 if err != nil { 1496 return false 1497 } 1498 1499 switch msg := msg.(type) { 1500 case *pgproto3.DataRow: 1501 rr.rowValues = msg.Values 1502 return true 1503 } 1504 } 1505 1506 return false 1507 } 1508 1509 // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until 1510 // the ResultReader is closed. 1511 func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { 1512 return rr.fieldDescriptions 1513 } 1514 1515 // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only 1516 // valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to 1517 // retain a reference to and mutate. 1518 func (rr *ResultReader) Values() [][]byte { 1519 return rr.rowValues 1520 } 1521 1522 // Close consumes any remaining result data and returns the command tag or 1523 // error. 
func (rr *ResultReader) Close() (CommandTag, error) {
	// Close is idempotent: once closed, the stored outcome is returned without touching the connection.
	if rr.closed {
		return rr.commandTag, rr.err
	}
	rr.closed = true

	// Drain any unread rows of the current command.
	for !rr.commandConcluded {
		_, err := rr.receiveMessage()
		if err != nil {
			return nil, rr.err
		}
	}

	// A standalone reader (ExecParams/ExecPrepared) must consume the response up to ReadyForQuery itself, then release
	// the context watcher and the lock. A reader created by a MultiResultReader leaves that to its owner.
	if rr.multiResultReader == nil {
		for {
			msg, err := rr.receiveMessage()
			if err != nil {
				return nil, rr.err
			}

			switch msg := msg.(type) {
			// Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete.
			case *pgproto3.ErrorResponse:
				rr.err = ErrorResponseToPgError(msg)
			case *pgproto3.ReadyForQuery:
				rr.pgConn.contextWatcher.Unwatch()
				rr.pgConn.unlock()
				return rr.commandTag, rr.err
			}
		}
	}

	return rr.commandTag, rr.err
}

// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any
// error will be stored in the ResultReader.
func (rr *ResultReader) readUntilRowDescription() {
	for !rr.commandConcluded {
		// Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription message.
		// This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are
		// manually used to construct a query that does not issue a describe statement.
		// The peek error is deliberately ignored. Presumably the same error comes back from receiveMessage below, which
		// concludes the command and ends the loop — confirm against peekMessage.
		msg, _ := rr.pgConn.peekMessage()
		if _, ok := msg.(*pgproto3.DataRow); ok {
			return
		}

		// Consume the message. receiveMessage stores any error in rr, so it is safe to ignore here.
		msg, _ = rr.receiveMessage()
		if _, ok := msg.(*pgproto3.RowDescription); ok {
			return
		}
	}
}

// receiveMessage reads the next backend message, either directly from pgConn or through the owning
// MultiResultReader, and updates the reader's state.
//
// - RowDescription loads the field descriptions.
// - CommandComplete, EmptyQueryResponse and ErrorResponse conclude the command.
// - On a receive error, the command is concluded with the error (preferring a context error over a net timeout). The
//   context watcher is released and the reader is closed. A standalone reader also closes the connection
//   asynchronously. rr.err is returned, which keeps the first error recorded.
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
	if rr.multiResultReader == nil {
		msg, err = rr.pgConn.receiveMessage()
	} else {
		msg, err = rr.multiResultReader.receiveMessage()
	}

	if err != nil {
		err = preferContextOverNetTimeoutError(rr.ctx, err)
		rr.concludeCommand(nil, err)
		rr.pgConn.contextWatcher.Unwatch()
		rr.closed = true
		// When owned by a MultiResultReader, its receiveMessage has already closed the connection.
		if rr.multiResultReader == nil {
			rr.pgConn.asyncClose()
		}

		return nil, rr.err
	}

	switch msg := msg.(type) {
	case *pgproto3.RowDescription:
		rr.fieldDescriptions = msg.Fields
	case *pgproto3.CommandComplete:
		rr.concludeCommand(CommandTag(msg.CommandTag), nil)
	case *pgproto3.EmptyQueryResponse:
		rr.concludeCommand(nil, nil)
	case *pgproto3.ErrorResponse:
		rr.concludeCommand(nil, ErrorResponseToPgError(msg))
	}

	return msg, nil
}

// concludeCommand marks the current command as finished with commandTag and clears the current row. Only the first
// call sets the tag; a non-nil err is kept if no error has been recorded yet.
func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
	// Keep the first error that is recorded. Store the error before checking if the command is already concluded to
	// allow for receiving an error after CommandComplete but before ReadyForQuery.
	if err != nil && rr.err == nil {
		rr.err = err
	}

	if rr.commandConcluded {
		return
	}

	rr.commandTag = commandTag
	rr.rowValues = nil
	rr.commandConcluded = true
}

// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
	buf []byte // encoded extended-protocol messages appended by ExecParams/ExecPrepared
}

// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
1634 func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { 1635 batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) 1636 batch.ExecPrepared("", paramValues, paramFormats, resultFormats) 1637 } 1638 1639 // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. 1640 func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { 1641 batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) 1642 batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) 1643 batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) 1644 } 1645 1646 // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a 1647 // transaction is already in progress or SQL contains transaction control statements. 1648 func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { 1649 if err := pgConn.lock(); err != nil { 1650 return &MultiResultReader{ 1651 closed: true, 1652 err: err, 1653 } 1654 } 1655 1656 pgConn.multiResultReader = MultiResultReader{ 1657 pgConn: pgConn, 1658 ctx: ctx, 1659 } 1660 multiResult := &pgConn.multiResultReader 1661 1662 if ctx != context.Background() { 1663 select { 1664 case <-ctx.Done(): 1665 multiResult.closed = true 1666 multiResult.err = newContextAlreadyDoneError(ctx) 1667 pgConn.unlock() 1668 return multiResult 1669 default: 1670 } 1671 pgConn.contextWatcher.Watch(ctx) 1672 } 1673 1674 batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) 1675 1676 // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is 1677 // closed. 
This is all that can be done without introducing a race condition or adding a concurrent safe communication 1678 // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. 1679 // The error the code reading the batch results receives will be a closed connection error. 1680 // 1681 // See https://github.com/jackc/pgx/issues/374. 1682 go func() { 1683 _, err := pgConn.conn.Write(batch.buf) 1684 if err != nil { 1685 pgConn.conn.Close() 1686 } 1687 }() 1688 1689 return multiResult 1690 } 1691 1692 // EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include 1693 // the surrounding single quotes. 1694 // 1695 // The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these 1696 // conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. 1697 func (pgConn *PgConn) EscapeString(s string) (string, error) { 1698 if pgConn.ParameterStatus("standard_conforming_strings") != "on" { 1699 return "", errors.New("EscapeString must be run with standard_conforming_strings=on") 1700 } 1701 1702 if pgConn.ParameterStatus("client_encoding") != "UTF8" { 1703 return "", errors.New("EscapeString must be run with client_encoding=UTF8") 1704 } 1705 1706 return strings.Replace(s, "'", "''", -1), nil 1707 } 1708 1709 // HijackedConn is the result of hijacking a connection. 1710 // 1711 // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning 1712 // compatibility. 
1713 type HijackedConn struct { 1714 Conn net.Conn // the underlying TCP or unix domain socket connection 1715 PID uint32 // backend pid 1716 SecretKey uint32 // key to use to send a cancel query message to the server 1717 ParameterStatuses map[string]string // parameters that have been reported by the server 1718 TxStatus byte 1719 Frontend Frontend 1720 Config *Config 1721 } 1722 1723 // Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. 1724 // Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the 1725 // raw connection after that (e.g. a load balancer or proxy). 1726 // 1727 // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning 1728 // compatibility. 1729 func (pgConn *PgConn) Hijack() (*HijackedConn, error) { 1730 if err := pgConn.lock(); err != nil { 1731 return nil, err 1732 } 1733 pgConn.status = connStatusClosed 1734 1735 return &HijackedConn{ 1736 Conn: pgConn.conn, 1737 PID: pgConn.pid, 1738 SecretKey: pgConn.secretKey, 1739 ParameterStatuses: pgConn.parameterStatuses, 1740 TxStatus: pgConn.txStatus, 1741 Frontend: pgConn.frontend, 1742 Config: pgConn.config, 1743 }, nil 1744 } 1745 1746 // Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of 1747 // PgConn.Hijack. The connection must be in an idle state. 1748 // 1749 // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning 1750 // compatibility. 
1751 func Construct(hc *HijackedConn) (*PgConn, error) { 1752 pgConn := &PgConn{ 1753 conn: hc.Conn, 1754 pid: hc.PID, 1755 secretKey: hc.SecretKey, 1756 parameterStatuses: hc.ParameterStatuses, 1757 txStatus: hc.TxStatus, 1758 frontend: hc.Frontend, 1759 config: hc.Config, 1760 1761 status: connStatusIdle, 1762 1763 wbuf: make([]byte, 0, wbufLen), 1764 cleanupDone: make(chan struct{}), 1765 } 1766 1767 pgConn.contextWatcher = newContextWatcher(pgConn.conn) 1768 1769 return pgConn, nil 1770 }