config.go (27638B)
1 package pgconn 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/pem" 8 "errors" 9 "fmt" 10 "io" 11 "math" 12 "net" 13 "net/url" 14 "os" 15 "path/filepath" 16 "strconv" 17 "strings" 18 "time" 19 20 "github.com/jackc/pgpassfile" 21 "github.com/jackc/pgservicefile" 22 "github.com/jackc/pgx/v5/pgproto3" 23 ) 24 25 type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error 26 type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error 27 type GetSSLPasswordFunc func(ctx context.Context) string 28 29 // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A 30 // manually initialized Config will cause ConnectConfig to panic. 31 type Config struct { 32 Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) 33 Port uint16 34 Database string 35 User string 36 Password string 37 TLSConfig *tls.Config // nil disables TLS 38 ConnectTimeout time.Duration 39 DialFunc DialFunc // e.g. net.Dialer.DialContext 40 LookupFunc LookupFunc // e.g. net.Resolver.LookupHost 41 BuildFrontend BuildFrontendFunc 42 RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) 43 44 KerberosSrvName string 45 KerberosSpn string 46 Fallbacks []*FallbackConfig 47 48 // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. 49 // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next 50 // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. 51 ValidateConnect ValidateConnectFunc 52 53 // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables 54 // or prepare statements). If this returns an error the connection attempt fails. 55 AfterConnect AfterConnectFunc 56 57 // OnNotice is a callback function called when a notice response is received. 58 OnNotice NoticeHandler 59 60 // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. 61 OnNotification NotificationHandler 62 63 createdByParseConfig bool // Used to enforce created by ParseConfig rule. 64 } 65 66 // ParseConfigOptions contains options that control how a config is built such as GetSSLPassword. 67 type ParseConfigOptions struct { 68 // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function 69 // PQsetSSLKeyPassHook_OpenSSL. 70 GetSSLPassword GetSSLPasswordFunc 71 } 72 73 // Copy returns a deep copy of the config that is safe to use and modify. 74 // The only exception is the TLSConfig field: 75 // according to the tls.Config docs it must not be modified after creation. 76 func (c *Config) Copy() *Config { 77 newConf := new(Config) 78 *newConf = *c 79 if newConf.TLSConfig != nil { 80 newConf.TLSConfig = c.TLSConfig.Clone() 81 } 82 if newConf.RuntimeParams != nil { 83 newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) 84 for k, v := range c.RuntimeParams { 85 newConf.RuntimeParams[k] = v 86 } 87 } 88 if newConf.Fallbacks != nil { 89 newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) 90 for i, fallback := range c.Fallbacks { 91 newFallback := new(FallbackConfig) 92 *newFallback = *fallback 93 if newFallback.TLSConfig != nil { 94 newFallback.TLSConfig = fallback.TLSConfig.Clone() 95 } 96 newConf.Fallbacks[i] = newFallback 97 } 98 } 99 return newConf 100 } 101 102 // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a 103 // network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. 104 type FallbackConfig struct { 105 Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) 106 Port uint16 107 TLSConfig *tls.Config // nil disables TLS 108 } 109 110 // isAbsolutePath checks if the provided value is an absolute path either 111 // beginning with a forward slash (as on Linux-based systems) or with a capital 112 // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). 113 func isAbsolutePath(path string) bool { 114 isWindowsPath := func(p string) bool { 115 if len(p) < 3 { 116 return false 117 } 118 drive := p[0] 119 colon := p[1] 120 backslash := p[2] 121 if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { 122 return true 123 } 124 return false 125 } 126 return strings.HasPrefix(path, "/") || isWindowsPath(path) 127 } 128 129 // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with 130 // net.Dial. 131 func NetworkAddress(host string, port uint16) (network, address string) { 132 if isAbsolutePath(host) { 133 network = "unix" 134 address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) 135 } else { 136 network = "tcp" 137 address = net.JoinHostPort(host, strconv.Itoa(int(port))) 138 } 139 return network, address 140 } 141 142 // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It 143 // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely 144 // matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). 145 // See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be 146 // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. 147 // 148 // # Example DSN 149 // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca 150 // 151 // # Example URL 152 // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca 153 // 154 // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done 155 // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be 156 // interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should 157 // not be modified individually. They should all be modified or all left unchanged. 158 // 159 // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated 160 // values that will be tried in order. This can be used as part of a high availability system. See 161 // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. 162 // 163 // # Example URL 164 // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb 165 // 166 // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed 167 // via database URL or DSN: 168 // 169 // PGHOST 170 // PGPORT 171 // PGDATABASE 172 // PGUSER 173 // PGPASSWORD 174 // PGPASSFILE 175 // PGSERVICE 176 // PGSERVICEFILE 177 // PGSSLMODE 178 // PGSSLCERT 179 // PGSSLKEY 180 // PGSSLROOTCERT 181 // PGSSLPASSWORD 182 // PGAPPNAME 183 // PGCONNECT_TIMEOUT 184 // PGTARGETSESSIONATTRS 185 // 186 // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. 187 // 188 // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are 189 // usually but not always the environment variable name downcased and without the "PG" prefix. 190 // 191 // Important Security Notes: 192 // 193 // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if 194 // not set. 195 // 196 // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of 197 // security each sslmode provides. 198 // 199 // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of 200 // the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of 201 // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback 202 // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually 203 // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting 204 // TLSConfig. 205 // 206 // Other known differences with libpq: 207 // 208 // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn 209 // does not. 210 // 211 // In addition, ParseConfig accepts the following options: 212 // 213 // - servicefile. 214 // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a 215 // part of the connection string. 216 func ParseConfig(connString string) (*Config, error) { 217 var parseConfigOptions ParseConfigOptions 218 return ParseConfigWithOptions(connString, parseConfigOptions) 219 } 220 221 // ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard 222 // C library libpq. options contains settings that cannot be specified in a connString such as providing a function to 223 // get the SSL password. 224 func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { 225 defaultSettings := defaultSettings() 226 envSettings := parseEnvSettings() 227 228 connStringSettings := make(map[string]string) 229 if connString != "" { 230 var err error 231 // connString may be a database URL or a DSN 232 if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { 233 connStringSettings, err = parseURLSettings(connString) 234 if err != nil { 235 return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} 236 } 237 } else { 238 connStringSettings, err = parseDSNSettings(connString) 239 if err != nil { 240 return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} 241 } 242 } 243 } 244 245 settings := mergeSettings(defaultSettings, envSettings, connStringSettings) 246 if service, present := settings["service"]; present { 247 serviceSettings, err := parseServiceSettings(settings["servicefile"], service) 248 if err != nil { 249 return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} 250 } 251 252 settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) 253 } 254 255 config := &Config{ 256 createdByParseConfig: true, 257 Database: settings["database"], 258 User: settings["user"], 259 Password: settings["password"], 260 RuntimeParams: make(map[string]string), 261 BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { 262 return pgproto3.NewFrontend(r, w) 263 }, 264 } 265 266 if connectTimeoutSetting, present := settings["connect_timeout"]; present { 267 connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) 268 if err != nil { 269 return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} 270 } 271 config.ConnectTimeout = connectTimeout 272 config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) 273 } else { 274 defaultDialer := makeDefaultDialer() 275 config.DialFunc = defaultDialer.DialContext 276 } 277 278 config.LookupFunc = makeDefaultResolver().LookupHost 279 280 notRuntimeParams := map[string]struct{}{ 281 "host": {}, 282 "port": {}, 283 "database": {}, 284 "user": {}, 285 "password": {}, 286 "passfile": {}, 287 "connect_timeout": {}, 288 "sslmode": {}, 289 "sslkey": {}, 290 "sslcert": {}, 291 "sslrootcert": {}, 292 "sslpassword": {}, 293 "sslsni": {}, 294 "krbspn": {}, 295 "krbsrvname": {}, 296 "target_session_attrs": {}, 297 "service": {}, 298 "servicefile": {}, 299 } 300 301 // Adding kerberos configuration 302 if _, present := settings["krbsrvname"]; present { 303 config.KerberosSrvName = settings["krbsrvname"] 304 } 305 if _, present := settings["krbspn"]; present { 306 config.KerberosSpn = settings["krbspn"] 307 } 308 309 for k, v := range settings { 310 if _, present := notRuntimeParams[k]; present { 311 continue 312 } 313 config.RuntimeParams[k] = v 314 } 315 316 fallbacks := []*FallbackConfig{} 317 318 hosts := strings.Split(settings["host"], ",") 319 ports := strings.Split(settings["port"], ",") 320 321 for i, host := range hosts { 322 var portStr string 323 if i < len(ports) { 324 portStr = ports[i] 325 } else { 326 portStr = ports[0] 327 } 328 329 port, err := parsePort(portStr) 330 if err != nil { 331 return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} 332 } 333 334 var tlsConfigs []*tls.Config 335 336 // Ignore TLS settings if Unix domain socket like libpq 337 if network, _ := NetworkAddress(host, port); network == "unix" { 338 tlsConfigs = append(tlsConfigs, nil) 339 } else { 340 var err error 341 tlsConfigs, err = configTLS(settings, host, options) 342 if err != nil { 343 return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} 344 } 345 } 346 347 for _, tlsConfig := range tlsConfigs { 348 fallbacks = append(fallbacks, &FallbackConfig{ 349 Host: host, 350 Port: port, 351 TLSConfig: tlsConfig, 352 }) 353 } 354 } 355 356 config.Host = fallbacks[0].Host 357 config.Port = fallbacks[0].Port 358 config.TLSConfig = fallbacks[0].TLSConfig 359 config.Fallbacks = fallbacks[1:] 360 361 passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) 362 if err == nil { 363 if config.Password == "" { 364 host := config.Host 365 if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { 366 host = "localhost" 367 } 368 369 config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) 370 } 371 } 372 373 switch tsa := settings["target_session_attrs"]; tsa { 374 case "read-write": 375 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite 376 case "read-only": 377 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly 378 case "primary": 379 config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary 380 case "standby": 381 config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby 382 case "prefer-standby": 383 config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby 384 case "any": 385 // do nothing 386 default: 387 return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} 388 } 389 390 return config, nil 391 } 392 393 func mergeSettings(settingSets ...map[string]string) map[string]string { 394 settings := make(map[string]string) 395 396 for _, s2 := range settingSets { 397 for k, v := range s2 { 398 settings[k] = v 399 } 400 } 401 402 return settings 403 } 404 405 func parseEnvSettings() map[string]string { 406 settings := make(map[string]string) 407 408 nameMap := map[string]string{ 409 "PGHOST": "host", 410 "PGPORT": "port", 411 "PGDATABASE": "database", 412 "PGUSER": "user", 413 "PGPASSWORD": "password", 414 "PGPASSFILE": "passfile", 415 "PGAPPNAME": "application_name", 416 "PGCONNECT_TIMEOUT": "connect_timeout", 417 "PGSSLMODE": "sslmode", 418 "PGSSLKEY": "sslkey", 419 "PGSSLCERT": "sslcert", 420 "PGSSLSNI": "sslsni", 421 "PGSSLROOTCERT": "sslrootcert", 422 "PGSSLPASSWORD": "sslpassword", 423 "PGTARGETSESSIONATTRS": "target_session_attrs", 424 "PGSERVICE": "service", 425 "PGSERVICEFILE": "servicefile", 426 } 427 428 for envname, realname := range nameMap { 429 value := os.Getenv(envname) 430 if value != "" { 431 settings[realname] = value 432 } 433 } 434 435 return settings 436 } 437 438 func parseURLSettings(connString string) (map[string]string, error) { 439 settings := make(map[string]string) 440 441 url, err := url.Parse(connString) 442 if err != nil { 443 return nil, err 444 } 445 446 if url.User != nil { 447 settings["user"] = url.User.Username() 448 if password, present := url.User.Password(); present { 449 settings["password"] = password 450 } 451 } 452 453 // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. 454 var hosts []string 455 var ports []string 456 for _, host := range strings.Split(url.Host, ",") { 457 if host == "" { 458 continue 459 } 460 if isIPOnly(host) { 461 hosts = append(hosts, strings.Trim(host, "[]")) 462 continue 463 } 464 h, p, err := net.SplitHostPort(host) 465 if err != nil { 466 return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) 467 } 468 if h != "" { 469 hosts = append(hosts, h) 470 } 471 if p != "" { 472 ports = append(ports, p) 473 } 474 } 475 if len(hosts) > 0 { 476 settings["host"] = strings.Join(hosts, ",") 477 } 478 if len(ports) > 0 { 479 settings["port"] = strings.Join(ports, ",") 480 } 481 482 database := strings.TrimLeft(url.Path, "/") 483 if database != "" { 484 settings["database"] = database 485 } 486 487 nameMap := map[string]string{ 488 "dbname": "database", 489 } 490 491 for k, v := range url.Query() { 492 if k2, present := nameMap[k]; present { 493 k = k2 494 } 495 496 settings[k] = v[0] 497 } 498 499 return settings, nil 500 } 501 502 func isIPOnly(host string) bool { 503 return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") 504 } 505 506 var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} 507 508 func parseDSNSettings(s string) (map[string]string, error) { 509 settings := make(map[string]string) 510 511 nameMap := map[string]string{ 512 "dbname": "database", 513 } 514 515 for len(s) > 0 { 516 var key, val string 517 eqIdx := strings.IndexRune(s, '=') 518 if eqIdx < 0 { 519 return nil, errors.New("invalid dsn") 520 } 521 522 key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") 523 s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") 524 if len(s) == 0 { 525 } else if s[0] != '\'' { 526 end := 0 527 for ; end < len(s); end++ { 528 if asciiSpace[s[end]] == 1 { 529 break 530 } 531 if s[end] == '\\' { 532 end++ 533 if end == len(s) { 534 return nil, errors.New("invalid backslash") 535 } 536 } 537 } 538 val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) 539 if end == len(s) { 540 s = "" 541 } else { 542 s = s[end+1:] 543 } 544 } else { // quoted string 545 s = s[1:] 546 end := 0 547 for ; end < len(s); end++ { 548 if s[end] == '\'' { 549 break 550 } 551 if s[end] == '\\' { 552 end++ 553 } 554 } 555 if end == len(s) { 556 return nil, errors.New("unterminated quoted string in connection info string") 557 } 558 val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) 559 if end == len(s) { 560 s = "" 561 } else { 562 s = s[end+1:] 563 } 564 } 565 566 if k, ok := nameMap[key]; ok { 567 key = k 568 } 569 570 if key == "" { 571 return nil, errors.New("invalid dsn") 572 } 573 574 settings[key] = val 575 } 576 577 return settings, nil 578 } 579 580 func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { 581 servicefile, err := pgservicefile.ReadServicefile(servicefilePath) 582 if err != nil { 583 return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) 584 } 585 586 service, err := servicefile.GetService(serviceName) 587 if err != nil { 588 return nil, fmt.Errorf("unable to find service: %v", serviceName) 589 } 590 591 nameMap := map[string]string{ 592 "dbname": "database", 593 } 594 595 settings := make(map[string]string, len(service.Settings)) 596 for k, v := range service.Settings { 597 if k2, present := nameMap[k]; present { 598 k = k2 599 } 600 settings[k] = v 601 } 602 603 return settings, nil 604 } 605 606 // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is 607 // necessary to allow returning multiple TLS configs as sslmode "allow" and 608 // "prefer" allow fallback. 609 func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { 610 host := thisHost 611 sslmode := settings["sslmode"] 612 sslrootcert := settings["sslrootcert"] 613 sslcert := settings["sslcert"] 614 sslkey := settings["sslkey"] 615 sslpassword := settings["sslpassword"] 616 sslsni := settings["sslsni"] 617 618 // Match libpq default behavior 619 if sslmode == "" { 620 sslmode = "prefer" 621 } 622 if sslsni == "" { 623 sslsni = "1" 624 } 625 626 tlsConfig := &tls.Config{} 627 628 switch sslmode { 629 case "disable": 630 return []*tls.Config{nil}, nil 631 case "allow", "prefer": 632 tlsConfig.InsecureSkipVerify = true 633 case "require": 634 // According to PostgreSQL documentation, if a root CA file exists, 635 // the behavior of sslmode=require should be the same as that of verify-ca 636 // 637 // See https://www.postgresql.org/docs/12/libpq-ssl.html 638 if sslrootcert != "" { 639 goto nextCase 640 } 641 tlsConfig.InsecureSkipVerify = true 642 break 643 nextCase: 644 fallthrough 645 case "verify-ca": 646 // Don't perform the default certificate verification because it 647 // will verify the hostname. Instead, verify the server's 648 // certificate chain ourselves in VerifyPeerCertificate and 649 // ignore the server name. This emulates libpq's verify-ca 650 // behavior. 651 // 652 // See https://github.com/golang/go/issues/21971#issuecomment-332693931 653 // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate 654 // for more info. 655 tlsConfig.InsecureSkipVerify = true 656 tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { 657 certs := make([]*x509.Certificate, len(certificates)) 658 for i, asn1Data := range certificates { 659 cert, err := x509.ParseCertificate(asn1Data) 660 if err != nil { 661 return errors.New("failed to parse certificate from server: " + err.Error()) 662 } 663 certs[i] = cert 664 } 665 666 // Leave DNSName empty to skip hostname verification. 667 opts := x509.VerifyOptions{ 668 Roots: tlsConfig.RootCAs, 669 Intermediates: x509.NewCertPool(), 670 } 671 // Skip the first cert because it's the leaf. All others 672 // are intermediates. 673 for _, cert := range certs[1:] { 674 opts.Intermediates.AddCert(cert) 675 } 676 _, err := certs[0].Verify(opts) 677 return err 678 } 679 case "verify-full": 680 tlsConfig.ServerName = host 681 default: 682 return nil, errors.New("sslmode is invalid") 683 } 684 685 if sslrootcert != "" { 686 caCertPool := x509.NewCertPool() 687 688 caPath := sslrootcert 689 caCert, err := os.ReadFile(caPath) 690 if err != nil { 691 return nil, fmt.Errorf("unable to read CA file: %w", err) 692 } 693 694 if !caCertPool.AppendCertsFromPEM(caCert) { 695 return nil, errors.New("unable to add CA to cert pool") 696 } 697 698 tlsConfig.RootCAs = caCertPool 699 tlsConfig.ClientCAs = caCertPool 700 } 701 702 if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { 703 return nil, errors.New(`both "sslcert" and "sslkey" are required`) 704 } 705 706 if sslcert != "" && sslkey != "" { 707 buf, err := os.ReadFile(sslkey) 708 if err != nil { 709 return nil, fmt.Errorf("unable to read sslkey: %w", err) 710 } 711 block, _ := pem.Decode(buf) 712 var pemKey []byte 713 var decryptedKey []byte 714 var decryptedError error 715 // If PEM is encrypted, attempt to decrypt using pass phrase 716 if x509.IsEncryptedPEMBlock(block) { 717 // Attempt decryption with pass phrase 718 // NOTE: only supports RSA (PKCS#1) 719 if sslpassword != "" { 720 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) 721 } 722 //if sslpassword not provided or has decryption error when use it 723 //try to find sslpassword with callback function 724 if sslpassword == "" || decryptedError != nil { 725 if parseConfigOptions.GetSSLPassword != nil { 726 sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) 727 } 728 if sslpassword == "" { 729 return nil, fmt.Errorf("unable to find sslpassword") 730 } 731 } 732 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) 733 // Should we also provide warning for PKCS#1 needed? 734 if decryptedError != nil { 735 return nil, fmt.Errorf("unable to decrypt key: %w", err) 736 } 737 738 pemBytes := pem.Block{ 739 Type: "RSA PRIVATE KEY", 740 Bytes: decryptedKey, 741 } 742 pemKey = pem.EncodeToMemory(&pemBytes) 743 } else { 744 pemKey = pem.EncodeToMemory(block) 745 } 746 certfile, err := os.ReadFile(sslcert) 747 if err != nil { 748 return nil, fmt.Errorf("unable to read cert: %w", err) 749 } 750 cert, err := tls.X509KeyPair(certfile, pemKey) 751 if err != nil { 752 return nil, fmt.Errorf("unable to load cert: %w", err) 753 } 754 tlsConfig.Certificates = []tls.Certificate{cert} 755 } 756 757 // Set Server Name Indication (SNI), if enabled by connection parameters. 758 // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 759 // or IPv6). 760 if sslsni == "1" && net.ParseIP(host) == nil { 761 tlsConfig.ServerName = host 762 } 763 764 switch sslmode { 765 case "allow": 766 return []*tls.Config{nil, tlsConfig}, nil 767 case "prefer": 768 return []*tls.Config{tlsConfig, nil}, nil 769 case "require", "verify-ca", "verify-full": 770 return []*tls.Config{tlsConfig}, nil 771 default: 772 panic("BUG: bad sslmode should already have been caught") 773 } 774 } 775 776 func parsePort(s string) (uint16, error) { 777 port, err := strconv.ParseUint(s, 10, 16) 778 if err != nil { 779 return 0, err 780 } 781 if port < 1 || port > math.MaxUint16 { 782 return 0, errors.New("outside range") 783 } 784 return uint16(port), nil 785 } 786 787 func makeDefaultDialer() *net.Dialer { 788 return &net.Dialer{KeepAlive: 5 * time.Minute} 789 } 790 791 func makeDefaultResolver() *net.Resolver { 792 return net.DefaultResolver 793 } 794 795 func parseConnectTimeoutSetting(s string) (time.Duration, error) { 796 timeout, err := strconv.ParseInt(s, 10, 64) 797 if err != nil { 798 return 0, err 799 } 800 if timeout < 0 { 801 return 0, errors.New("negative timeout") 802 } 803 return time.Duration(timeout) * time.Second, nil 804 } 805 806 func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { 807 d := makeDefaultDialer() 808 d.Timeout = timeout 809 return d.DialContext 810 } 811 812 // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible 813 // target_session_attrs=read-write. 814 func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { 815 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() 816 if result.Err != nil { 817 return result.Err 818 } 819 820 if string(result.Rows[0][0]) == "on" { 821 return errors.New("read only connection") 822 } 823 824 return nil 825 } 826 827 // ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible 828 // target_session_attrs=read-only. 829 func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { 830 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() 831 if result.Err != nil { 832 return result.Err 833 } 834 835 if string(result.Rows[0][0]) != "on" { 836 return errors.New("connection is not read only") 837 } 838 839 return nil 840 } 841 842 // ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible 843 // target_session_attrs=standby. 844 func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { 845 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 846 if result.Err != nil { 847 return result.Err 848 } 849 850 if string(result.Rows[0][0]) != "t" { 851 return errors.New("server is not in hot standby mode") 852 } 853 854 return nil 855 } 856 857 // ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible 858 // target_session_attrs=primary. 859 func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { 860 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 861 if result.Err != nil { 862 return result.Err 863 } 864 865 if string(result.Rows[0][0]) == "t" { 866 return errors.New("server is in standby mode") 867 } 868 869 return nil 870 } 871 872 // ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible 873 // target_session_attrs=prefer-standby. 874 func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { 875 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 876 if result.Err != nil { 877 return result.Err 878 } 879 880 if string(result.Rows[0][0]) != "t" { 881 return &NotPreferredError{err: errors.New("server is not in hot standby mode")} 882 } 883 884 return nil 885 }