config.go (28397B)
1 package pgconn 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/pem" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "math" 13 "net" 14 "net/url" 15 "os" 16 "path/filepath" 17 "strconv" 18 "strings" 19 "time" 20 21 "github.com/jackc/chunkreader/v2" 22 "github.com/jackc/pgpassfile" 23 "github.com/jackc/pgproto3/v2" 24 "github.com/jackc/pgservicefile" 25 ) 26 27 type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error 28 type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error 29 type GetSSLPasswordFunc func(ctx context.Context) string 30 31 // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A 32 // manually initialized Config will cause ConnectConfig to panic. 33 type Config struct { 34 Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) 35 Port uint16 36 Database string 37 User string 38 Password string 39 TLSConfig *tls.Config // nil disables TLS 40 ConnectTimeout time.Duration 41 DialFunc DialFunc // e.g. net.Dialer.DialContext 42 LookupFunc LookupFunc // e.g. net.Resolver.LookupHost 43 BuildFrontend BuildFrontendFunc 44 RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) 45 46 KerberosSrvName string 47 KerberosSpn string 48 Fallbacks []*FallbackConfig 49 50 // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. 51 // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next 52 // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. 53 ValidateConnect ValidateConnectFunc 54 55 // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. 
Set session variables 56 // or prepare statements). If this returns an error the connection attempt fails. 57 AfterConnect AfterConnectFunc 58 59 // OnNotice is a callback function called when a notice response is received. 60 OnNotice NoticeHandler 61 62 // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. 63 OnNotification NotificationHandler 64 65 createdByParseConfig bool // Used to enforce created by ParseConfig rule. 66 } 67 68 // ParseConfigOptions contains options that control how a config is built such as getsslpassword. 69 type ParseConfigOptions struct { 70 // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function 71 // PQsetSSLKeyPassHook_OpenSSL. 72 GetSSLPassword GetSSLPasswordFunc 73 } 74 75 // Copy returns a deep copy of the config that is safe to use and modify. 76 // The only exception is the TLSConfig field: 77 // according to the tls.Config docs it must not be modified after creation. 78 func (c *Config) Copy() *Config { 79 newConf := new(Config) 80 *newConf = *c 81 if newConf.TLSConfig != nil { 82 newConf.TLSConfig = c.TLSConfig.Clone() 83 } 84 if newConf.RuntimeParams != nil { 85 newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) 86 for k, v := range c.RuntimeParams { 87 newConf.RuntimeParams[k] = v 88 } 89 } 90 if newConf.Fallbacks != nil { 91 newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) 92 for i, fallback := range c.Fallbacks { 93 newFallback := new(FallbackConfig) 94 *newFallback = *fallback 95 if newFallback.TLSConfig != nil { 96 newFallback.TLSConfig = fallback.TLSConfig.Clone() 97 } 98 newConf.Fallbacks[i] = newFallback 99 } 100 } 101 return newConf 102 } 103 104 // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a 105 // network connection. 
// It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
type FallbackConfig struct {
	Host      string      // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
	Port      uint16
	TLSConfig *tls.Config // nil disables TLS
}

// isAbsolutePath reports whether path is an absolute filesystem path. Such a host is treated as a Unix domain
// socket directory rather than a TCP host name.
//
// Two forms are recognized regardless of the current OS:
//   - a leading forward slash, as on Unix-like systems (e.g. "/var/run/postgresql");
//   - a Windows drive path: an ASCII letter of either case, a colon, then a '\' or '/' separator
//     (e.g. `C:\pgsock` or `c:/pgsock`).
//
// The drive-path rule is modeled on libpq's Windows is_absolute_path. Neither form can be a valid DNS name or IP
// literal, so accepting them never redirects a real TCP host to a socket.
func isAbsolutePath(path string) bool {
	if strings.HasPrefix(path, "/") {
		return true
	}
	if len(path) < 3 {
		return false
	}
	drive, colon, sep := path[0], path[1], path[2]
	isLetter := (drive >= 'A' && drive <= 'Z') || (drive >= 'a' && drive <= 'z')
	return isLetter && colon == ':' && (sep == '\\' || sep == '/')
}

// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
//
// When host is an absolute path (see isAbsolutePath) the result is network "unix" with the socket file
// "<host>/.s.PGSQL.<port>". Otherwise the result is network "tcp" with "host:port" built by net.JoinHostPort,
// which brackets IPv6 literals.
func NetworkAddress(host string, port uint16) (network, address string) {
	portStr := strconv.FormatUint(uint64(port), 10)
	if isAbsolutePath(host) {
		return "unix", filepath.Join(host, ".s.PGSQL.") + portStr
	}
	return "tcp", net.JoinHostPort(host, portStr)
}

// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment.
If a password is not supplied it will attempt to read the .pgpass file. 149 // 150 // # Example DSN 151 // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca 152 // 153 // # Example URL 154 // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca 155 // 156 // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done 157 // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be 158 // interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should 159 // not be modified individually. They should all be modified or all left unchanged. 160 // 161 // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated 162 // values that will be tried in order. This can be used as part of a high availability system. See 163 // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. 164 // 165 // # Example URL 166 // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb 167 // 168 // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed 169 // via database URL or DSN: 170 // 171 // PGHOST 172 // PGPORT 173 // PGDATABASE 174 // PGUSER 175 // PGPASSWORD 176 // PGPASSFILE 177 // PGSERVICE 178 // PGSERVICEFILE 179 // PGSSLMODE 180 // PGSSLCERT 181 // PGSSLKEY 182 // PGSSLROOTCERT 183 // PGSSLPASSWORD 184 // PGAPPNAME 185 // PGCONNECT_TIMEOUT 186 // PGTARGETSESSIONATTRS 187 // 188 // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. 189 // 190 // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. 
They are 191 // usually but not always the environment variable name downcased and without the "PG" prefix. 192 // 193 // Important Security Notes: 194 // 195 // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if 196 // not set. 197 // 198 // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of 199 // security each sslmode provides. 200 // 201 // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of 202 // the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of 203 // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback 204 // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually 205 // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting 206 // TLSConfig. 207 // 208 // Other known differences with libpq: 209 // 210 // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn 211 // does not. 212 // 213 // In addition, ParseConfig accepts the following options: 214 // 215 // min_read_buffer_size 216 // The minimum size of the internal read buffer. Default 8192. 217 // servicefile 218 // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a 219 // part of the connection string. 220 func ParseConfig(connString string) (*Config, error) { 221 var parseConfigOptions ParseConfigOptions 222 return ParseConfigWithOptions(connString, parseConfigOptions) 223 } 224 225 // ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard 226 // C library libpq. 
options contains settings that cannot be specified in a connString such as providing a function to 227 // get the SSL password. 228 func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { 229 defaultSettings := defaultSettings() 230 envSettings := parseEnvSettings() 231 232 connStringSettings := make(map[string]string) 233 if connString != "" { 234 var err error 235 // connString may be a database URL or a DSN 236 if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { 237 connStringSettings, err = parseURLSettings(connString) 238 if err != nil { 239 return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} 240 } 241 } else { 242 connStringSettings, err = parseDSNSettings(connString) 243 if err != nil { 244 return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} 245 } 246 } 247 } 248 249 settings := mergeSettings(defaultSettings, envSettings, connStringSettings) 250 if service, present := settings["service"]; present { 251 serviceSettings, err := parseServiceSettings(settings["servicefile"], service) 252 if err != nil { 253 return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} 254 } 255 256 settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) 257 } 258 259 minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) 260 if err != nil { 261 return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} 262 } 263 264 config := &Config{ 265 createdByParseConfig: true, 266 Database: settings["database"], 267 User: settings["user"], 268 Password: settings["password"], 269 RuntimeParams: make(map[string]string), 270 BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), 271 } 272 273 if connectTimeoutSetting, present := settings["connect_timeout"]; present { 
274 connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) 275 if err != nil { 276 return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} 277 } 278 config.ConnectTimeout = connectTimeout 279 config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) 280 } else { 281 defaultDialer := makeDefaultDialer() 282 config.DialFunc = defaultDialer.DialContext 283 } 284 285 config.LookupFunc = makeDefaultResolver().LookupHost 286 287 notRuntimeParams := map[string]struct{}{ 288 "host": {}, 289 "port": {}, 290 "database": {}, 291 "user": {}, 292 "password": {}, 293 "passfile": {}, 294 "connect_timeout": {}, 295 "sslmode": {}, 296 "sslkey": {}, 297 "sslcert": {}, 298 "sslrootcert": {}, 299 "sslpassword": {}, 300 "sslsni": {}, 301 "krbspn": {}, 302 "krbsrvname": {}, 303 "target_session_attrs": {}, 304 "min_read_buffer_size": {}, 305 "service": {}, 306 "servicefile": {}, 307 } 308 309 // Adding kerberos configuration 310 if _, present := settings["krbsrvname"]; present { 311 config.KerberosSrvName = settings["krbsrvname"] 312 } 313 if _, present := settings["krbspn"]; present { 314 config.KerberosSpn = settings["krbspn"] 315 } 316 317 for k, v := range settings { 318 if _, present := notRuntimeParams[k]; present { 319 continue 320 } 321 config.RuntimeParams[k] = v 322 } 323 324 fallbacks := []*FallbackConfig{} 325 326 hosts := strings.Split(settings["host"], ",") 327 ports := strings.Split(settings["port"], ",") 328 329 for i, host := range hosts { 330 var portStr string 331 if i < len(ports) { 332 portStr = ports[i] 333 } else { 334 portStr = ports[0] 335 } 336 337 port, err := parsePort(portStr) 338 if err != nil { 339 return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} 340 } 341 342 var tlsConfigs []*tls.Config 343 344 // Ignore TLS settings if Unix domain socket like libpq 345 if network, _ := NetworkAddress(host, port); network == "unix" { 346 tlsConfigs = append(tlsConfigs, nil) 
347 } else { 348 var err error 349 tlsConfigs, err = configTLS(settings, host, options) 350 if err != nil { 351 return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} 352 } 353 } 354 355 for _, tlsConfig := range tlsConfigs { 356 fallbacks = append(fallbacks, &FallbackConfig{ 357 Host: host, 358 Port: port, 359 TLSConfig: tlsConfig, 360 }) 361 } 362 } 363 364 config.Host = fallbacks[0].Host 365 config.Port = fallbacks[0].Port 366 config.TLSConfig = fallbacks[0].TLSConfig 367 config.Fallbacks = fallbacks[1:] 368 369 passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) 370 if err == nil { 371 if config.Password == "" { 372 host := config.Host 373 if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { 374 host = "localhost" 375 } 376 377 config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) 378 } 379 } 380 381 switch tsa := settings["target_session_attrs"]; tsa { 382 case "read-write": 383 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite 384 case "read-only": 385 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly 386 case "primary": 387 config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary 388 case "standby": 389 config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby 390 case "prefer-standby": 391 config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby 392 case "any": 393 // do nothing 394 default: 395 return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} 396 } 397 398 return config, nil 399 } 400 401 func mergeSettings(settingSets ...map[string]string) map[string]string { 402 settings := make(map[string]string) 403 404 for _, s2 := range settingSets { 405 for k, v := range s2 { 406 settings[k] = v 407 } 408 } 409 410 return settings 411 } 412 413 func parseEnvSettings() map[string]string { 414 settings 
:= make(map[string]string) 415 416 nameMap := map[string]string{ 417 "PGHOST": "host", 418 "PGPORT": "port", 419 "PGDATABASE": "database", 420 "PGUSER": "user", 421 "PGPASSWORD": "password", 422 "PGPASSFILE": "passfile", 423 "PGAPPNAME": "application_name", 424 "PGCONNECT_TIMEOUT": "connect_timeout", 425 "PGSSLMODE": "sslmode", 426 "PGSSLKEY": "sslkey", 427 "PGSSLCERT": "sslcert", 428 "PGSSLSNI": "sslsni", 429 "PGSSLROOTCERT": "sslrootcert", 430 "PGSSLPASSWORD": "sslpassword", 431 "PGTARGETSESSIONATTRS": "target_session_attrs", 432 "PGSERVICE": "service", 433 "PGSERVICEFILE": "servicefile", 434 } 435 436 for envname, realname := range nameMap { 437 value := os.Getenv(envname) 438 if value != "" { 439 settings[realname] = value 440 } 441 } 442 443 return settings 444 } 445 446 func parseURLSettings(connString string) (map[string]string, error) { 447 settings := make(map[string]string) 448 449 url, err := url.Parse(connString) 450 if err != nil { 451 return nil, err 452 } 453 454 if url.User != nil { 455 settings["user"] = url.User.Username() 456 if password, present := url.User.Password(); present { 457 settings["password"] = password 458 } 459 } 460 461 // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. 
462 var hosts []string 463 var ports []string 464 for _, host := range strings.Split(url.Host, ",") { 465 if host == "" { 466 continue 467 } 468 if isIPOnly(host) { 469 hosts = append(hosts, strings.Trim(host, "[]")) 470 continue 471 } 472 h, p, err := net.SplitHostPort(host) 473 if err != nil { 474 return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) 475 } 476 if h != "" { 477 hosts = append(hosts, h) 478 } 479 if p != "" { 480 ports = append(ports, p) 481 } 482 } 483 if len(hosts) > 0 { 484 settings["host"] = strings.Join(hosts, ",") 485 } 486 if len(ports) > 0 { 487 settings["port"] = strings.Join(ports, ",") 488 } 489 490 database := strings.TrimLeft(url.Path, "/") 491 if database != "" { 492 settings["database"] = database 493 } 494 495 nameMap := map[string]string{ 496 "dbname": "database", 497 } 498 499 for k, v := range url.Query() { 500 if k2, present := nameMap[k]; present { 501 k = k2 502 } 503 504 settings[k] = v[0] 505 } 506 507 return settings, nil 508 } 509 510 func isIPOnly(host string) bool { 511 return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") 512 } 513 514 var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} 515 516 func parseDSNSettings(s string) (map[string]string, error) { 517 settings := make(map[string]string) 518 519 nameMap := map[string]string{ 520 "dbname": "database", 521 } 522 523 for len(s) > 0 { 524 var key, val string 525 eqIdx := strings.IndexRune(s, '=') 526 if eqIdx < 0 { 527 return nil, errors.New("invalid dsn") 528 } 529 530 key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") 531 s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") 532 if len(s) == 0 { 533 } else if s[0] != '\'' { 534 end := 0 535 for ; end < len(s); end++ { 536 if asciiSpace[s[end]] == 1 { 537 break 538 } 539 if s[end] == '\\' { 540 end++ 541 if end == len(s) { 542 return nil, errors.New("invalid backslash") 543 } 544 } 545 } 546 val = strings.Replace(strings.Replace(s[:end], 
"\\\\", "\\", -1), "\\'", "'", -1) 547 if end == len(s) { 548 s = "" 549 } else { 550 s = s[end+1:] 551 } 552 } else { // quoted string 553 s = s[1:] 554 end := 0 555 for ; end < len(s); end++ { 556 if s[end] == '\'' { 557 break 558 } 559 if s[end] == '\\' { 560 end++ 561 } 562 } 563 if end == len(s) { 564 return nil, errors.New("unterminated quoted string in connection info string") 565 } 566 val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) 567 if end == len(s) { 568 s = "" 569 } else { 570 s = s[end+1:] 571 } 572 } 573 574 if k, ok := nameMap[key]; ok { 575 key = k 576 } 577 578 if key == "" { 579 return nil, errors.New("invalid dsn") 580 } 581 582 settings[key] = val 583 } 584 585 return settings, nil 586 } 587 588 func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { 589 servicefile, err := pgservicefile.ReadServicefile(servicefilePath) 590 if err != nil { 591 return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) 592 } 593 594 service, err := servicefile.GetService(serviceName) 595 if err != nil { 596 return nil, fmt.Errorf("unable to find service: %v", serviceName) 597 } 598 599 nameMap := map[string]string{ 600 "dbname": "database", 601 } 602 603 settings := make(map[string]string, len(service.Settings)) 604 for k, v := range service.Settings { 605 if k2, present := nameMap[k]; present { 606 k = k2 607 } 608 settings[k] = v 609 } 610 611 return settings, nil 612 } 613 614 // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is 615 // necessary to allow returning multiple TLS configs as sslmode "allow" and 616 // "prefer" allow fallback. 
617 func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { 618 host := thisHost 619 sslmode := settings["sslmode"] 620 sslrootcert := settings["sslrootcert"] 621 sslcert := settings["sslcert"] 622 sslkey := settings["sslkey"] 623 sslpassword := settings["sslpassword"] 624 sslsni := settings["sslsni"] 625 626 // Match libpq default behavior 627 if sslmode == "" { 628 sslmode = "prefer" 629 } 630 if sslsni == "" { 631 sslsni = "1" 632 } 633 634 tlsConfig := &tls.Config{} 635 636 switch sslmode { 637 case "disable": 638 return []*tls.Config{nil}, nil 639 case "allow", "prefer": 640 tlsConfig.InsecureSkipVerify = true 641 case "require": 642 // According to PostgreSQL documentation, if a root CA file exists, 643 // the behavior of sslmode=require should be the same as that of verify-ca 644 // 645 // See https://www.postgresql.org/docs/12/libpq-ssl.html 646 if sslrootcert != "" { 647 goto nextCase 648 } 649 tlsConfig.InsecureSkipVerify = true 650 break 651 nextCase: 652 fallthrough 653 case "verify-ca": 654 // Don't perform the default certificate verification because it 655 // will verify the hostname. Instead, verify the server's 656 // certificate chain ourselves in VerifyPeerCertificate and 657 // ignore the server name. This emulates libpq's verify-ca 658 // behavior. 659 // 660 // See https://github.com/golang/go/issues/21971#issuecomment-332693931 661 // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate 662 // for more info. 
663 tlsConfig.InsecureSkipVerify = true 664 tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { 665 certs := make([]*x509.Certificate, len(certificates)) 666 for i, asn1Data := range certificates { 667 cert, err := x509.ParseCertificate(asn1Data) 668 if err != nil { 669 return errors.New("failed to parse certificate from server: " + err.Error()) 670 } 671 certs[i] = cert 672 } 673 674 // Leave DNSName empty to skip hostname verification. 675 opts := x509.VerifyOptions{ 676 Roots: tlsConfig.RootCAs, 677 Intermediates: x509.NewCertPool(), 678 } 679 // Skip the first cert because it's the leaf. All others 680 // are intermediates. 681 for _, cert := range certs[1:] { 682 opts.Intermediates.AddCert(cert) 683 } 684 _, err := certs[0].Verify(opts) 685 return err 686 } 687 case "verify-full": 688 tlsConfig.ServerName = host 689 default: 690 return nil, errors.New("sslmode is invalid") 691 } 692 693 if sslrootcert != "" { 694 caCertPool := x509.NewCertPool() 695 696 caPath := sslrootcert 697 caCert, err := ioutil.ReadFile(caPath) 698 if err != nil { 699 return nil, fmt.Errorf("unable to read CA file: %w", err) 700 } 701 702 if !caCertPool.AppendCertsFromPEM(caCert) { 703 return nil, errors.New("unable to add CA to cert pool") 704 } 705 706 tlsConfig.RootCAs = caCertPool 707 tlsConfig.ClientCAs = caCertPool 708 } 709 710 if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { 711 return nil, errors.New(`both "sslcert" and "sslkey" are required`) 712 } 713 714 if sslcert != "" && sslkey != "" { 715 buf, err := ioutil.ReadFile(sslkey) 716 if err != nil { 717 return nil, fmt.Errorf("unable to read sslkey: %w", err) 718 } 719 block, _ := pem.Decode(buf) 720 var pemKey []byte 721 var decryptedKey []byte 722 var decryptedError error 723 // If PEM is encrypted, attempt to decrypt using pass phrase 724 if x509.IsEncryptedPEMBlock(block) { 725 // Attempt decryption with pass phrase 726 // NOTE: only supports RSA (PKCS#1) 727 
if sslpassword != "" { 728 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) 729 } 730 //if sslpassword not provided or has decryption error when use it 731 //try to find sslpassword with callback function 732 if sslpassword == "" || decryptedError != nil { 733 if parseConfigOptions.GetSSLPassword != nil { 734 sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) 735 } 736 if sslpassword == "" { 737 return nil, fmt.Errorf("unable to find sslpassword") 738 } 739 } 740 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) 741 // Should we also provide warning for PKCS#1 needed? 742 if decryptedError != nil { 743 return nil, fmt.Errorf("unable to decrypt key: %w", err) 744 } 745 746 pemBytes := pem.Block{ 747 Type: "RSA PRIVATE KEY", 748 Bytes: decryptedKey, 749 } 750 pemKey = pem.EncodeToMemory(&pemBytes) 751 } else { 752 pemKey = pem.EncodeToMemory(block) 753 } 754 certfile, err := ioutil.ReadFile(sslcert) 755 if err != nil { 756 return nil, fmt.Errorf("unable to read cert: %w", err) 757 } 758 cert, err := tls.X509KeyPair(certfile, pemKey) 759 if err != nil { 760 return nil, fmt.Errorf("unable to load cert: %w", err) 761 } 762 tlsConfig.Certificates = []tls.Certificate{cert} 763 } 764 765 // Set Server Name Indication (SNI), if enabled by connection parameters. 766 // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 767 // or IPv6). 
768 if sslsni == "1" && net.ParseIP(host) == nil { 769 tlsConfig.ServerName = host 770 } 771 772 switch sslmode { 773 case "allow": 774 return []*tls.Config{nil, tlsConfig}, nil 775 case "prefer": 776 return []*tls.Config{tlsConfig, nil}, nil 777 case "require", "verify-ca", "verify-full": 778 return []*tls.Config{tlsConfig}, nil 779 default: 780 panic("BUG: bad sslmode should already have been caught") 781 } 782 } 783 784 func parsePort(s string) (uint16, error) { 785 port, err := strconv.ParseUint(s, 10, 16) 786 if err != nil { 787 return 0, err 788 } 789 if port < 1 || port > math.MaxUint16 { 790 return 0, errors.New("outside range") 791 } 792 return uint16(port), nil 793 } 794 795 func makeDefaultDialer() *net.Dialer { 796 return &net.Dialer{KeepAlive: 5 * time.Minute} 797 } 798 799 func makeDefaultResolver() *net.Resolver { 800 return net.DefaultResolver 801 } 802 803 func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { 804 return func(r io.Reader, w io.Writer) Frontend { 805 cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) 806 if err != nil { 807 panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) 808 } 809 frontend := pgproto3.NewFrontend(cr, w) 810 811 return frontend 812 } 813 } 814 815 func parseConnectTimeoutSetting(s string) (time.Duration, error) { 816 timeout, err := strconv.ParseInt(s, 10, 64) 817 if err != nil { 818 return 0, err 819 } 820 if timeout < 0 { 821 return 0, errors.New("negative timeout") 822 } 823 return time.Duration(timeout) * time.Second, nil 824 } 825 826 func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { 827 d := makeDefaultDialer() 828 d.Timeout = timeout 829 return d.DialContext 830 } 831 832 // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible 833 // target_session_attrs=read-write. 
834 func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { 835 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() 836 if result.Err != nil { 837 return result.Err 838 } 839 840 if string(result.Rows[0][0]) == "on" { 841 return errors.New("read only connection") 842 } 843 844 return nil 845 } 846 847 // ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible 848 // target_session_attrs=read-only. 849 func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { 850 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() 851 if result.Err != nil { 852 return result.Err 853 } 854 855 if string(result.Rows[0][0]) != "on" { 856 return errors.New("connection is not read only") 857 } 858 859 return nil 860 } 861 862 // ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible 863 // target_session_attrs=standby. 864 func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { 865 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 866 if result.Err != nil { 867 return result.Err 868 } 869 870 if string(result.Rows[0][0]) != "t" { 871 return errors.New("server is not in hot standby mode") 872 } 873 874 return nil 875 } 876 877 // ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible 878 // target_session_attrs=primary. 
879 func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { 880 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 881 if result.Err != nil { 882 return result.Err 883 } 884 885 if string(result.Rows[0][0]) == "t" { 886 return errors.New("server is in standby mode") 887 } 888 889 return nil 890 } 891 892 // ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible 893 // target_session_attrs=prefer-standby. 894 func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { 895 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 896 if result.Err != nil { 897 return result.Err 898 } 899 900 if string(result.Rows[0][0]) != "t" { 901 return &NotPreferredError{err: errors.New("server is not in hot standby mode")} 902 } 903 904 return nil 905 }