gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

config.go (28397B)


      1 package pgconn
      2 
      3 import (
      4 	"context"
      5 	"crypto/tls"
      6 	"crypto/x509"
      7 	"encoding/pem"
      8 	"errors"
      9 	"fmt"
     10 	"io"
     11 	"io/ioutil"
     12 	"math"
     13 	"net"
     14 	"net/url"
     15 	"os"
     16 	"path/filepath"
     17 	"strconv"
     18 	"strings"
     19 	"time"
     20 
     21 	"github.com/jackc/chunkreader/v2"
     22 	"github.com/jackc/pgpassfile"
     23 	"github.com/jackc/pgproto3/v2"
     24 	"github.com/jackc/pgservicefile"
     25 )
     26 
     27 type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
     28 type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
     29 type GetSSLPasswordFunc func(ctx context.Context) string
     30 
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
// manually initialized Config will cause ConnectConfig to panic.
//
// As populated by ParseConfigWithOptions:
//
//   - Host, Port and TLSConfig describe the first connection candidate; every further candidate (additional hosts
//     and TLS fallbacks) is stored in Fallbacks.
//   - Database, User and Password come from the "database", "user" and "password" settings. When Password is empty
//     the passfile is consulted.
//   - ConnectTimeout is only set when a connect_timeout setting is present; DialFunc is then a dialer using that
//     timeout, otherwise the default dialer.
//   - KerberosSrvName and KerberosSpn come from the "krbsrvname" and "krbspn" settings.
//   - RuntimeParams receives every setting that is not consumed by the parser itself.
type Config struct {
	Host           string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
	Port           uint16
	Database       string
	User           string
	Password       string
	TLSConfig      *tls.Config // nil disables TLS
	ConnectTimeout time.Duration
	DialFunc       DialFunc   // e.g. net.Dialer.DialContext
	LookupFunc     LookupFunc // e.g. net.Resolver.LookupHost
	BuildFrontend  BuildFrontendFunc
	RuntimeParams  map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)

	KerberosSrvName string
	KerberosSpn     string
	Fallbacks       []*FallbackConfig

	// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
	// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
	// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
	ValidateConnect ValidateConnectFunc

	// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
	// or prepare statements). If this returns an error the connection attempt fails.
	AfterConnect AfterConnectFunc

	// OnNotice is a callback function called when a notice response is received.
	OnNotice NoticeHandler

	// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
	OnNotification NotificationHandler

	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
     67 
// ParseConfigOptions contains options that control how a config is built and that cannot be expressed in a
// connection string, such as GetSSLPassword.
type ParseConfigOptions struct {
	// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
	// PQsetSSLKeyPassHook_OpenSSL. It is only consulted when the client key is an encrypted PEM block and the
	// sslpassword setting is empty or fails to decrypt the key. It may be nil.
	GetSSLPassword GetSSLPasswordFunc
}
     74 
     75 // Copy returns a deep copy of the config that is safe to use and modify.
     76 // The only exception is the TLSConfig field:
     77 // according to the tls.Config docs it must not be modified after creation.
     78 func (c *Config) Copy() *Config {
     79 	newConf := new(Config)
     80 	*newConf = *c
     81 	if newConf.TLSConfig != nil {
     82 		newConf.TLSConfig = c.TLSConfig.Clone()
     83 	}
     84 	if newConf.RuntimeParams != nil {
     85 		newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
     86 		for k, v := range c.RuntimeParams {
     87 			newConf.RuntimeParams[k] = v
     88 		}
     89 	}
     90 	if newConf.Fallbacks != nil {
     91 		newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
     92 		for i, fallback := range c.Fallbacks {
     93 			newFallback := new(FallbackConfig)
     94 			*newFallback = *fallback
     95 			if newFallback.TLSConfig != nil {
     96 				newFallback.TLSConfig = fallback.TLSConfig.Clone()
     97 			}
     98 			newConf.Fallbacks[i] = newFallback
     99 		}
    100 	}
    101 	return newConf
    102 }
    103 
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
//
// ParseConfigWithOptions creates one FallbackConfig per host/TLS combination after the first; the first combination
// is stored directly on the Config.
type FallbackConfig struct {
	Host      string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
	Port      uint16
	TLSConfig *tls.Config // nil disables TLS
}
    111 
    112 // isAbsolutePath checks if the provided value is an absolute path either
    113 // beginning with a forward slash (as on Linux-based systems) or with a capital
    114 // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
    115 func isAbsolutePath(path string) bool {
    116 	isWindowsPath := func(p string) bool {
    117 		if len(p) < 3 {
    118 			return false
    119 		}
    120 		drive := p[0]
    121 		colon := p[1]
    122 		backslash := p[2]
    123 		if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
    124 			return true
    125 		}
    126 		return false
    127 	}
    128 	return strings.HasPrefix(path, "/") || isWindowsPath(path)
    129 }
    130 
    131 // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
    132 // net.Dial.
    133 func NetworkAddress(host string, port uint16) (network, address string) {
    134 	if isAbsolutePath(host) {
    135 		network = "unix"
    136 		address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
    137 	} else {
    138 		network = "tcp"
    139 		address = net.JoinHostPort(host, strconv.Itoa(int(port)))
    140 	}
    141 	return network, address
    142 }
    143 
    144 // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
    145 // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
    146 // matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
    147 // See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
    148 // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
    149 //
    150 //   # Example DSN
    151 //   user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
    152 //
    153 //   # Example URL
    154 //   postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
    155 //
    156 // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
    157 // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
    158 // interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
    159 // not be modified individually. They should all be modified or all left unchanged.
    160 //
    161 // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
    162 // values that will be tried in order. This can be used as part of a high availability system. See
    163 // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
    164 //
    165 //   # Example URL
    166 //   postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
    167 //
    168 // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
    169 // via database URL or DSN:
    170 //
    171 //   PGHOST
    172 //   PGPORT
    173 //   PGDATABASE
    174 //   PGUSER
    175 //   PGPASSWORD
    176 //   PGPASSFILE
    177 //   PGSERVICE
    178 //   PGSERVICEFILE
    179 //   PGSSLMODE
    180 //   PGSSLCERT
    181 //   PGSSLKEY
    182 //   PGSSLROOTCERT
    183 //   PGSSLPASSWORD
    184 //   PGAPPNAME
    185 //   PGCONNECT_TIMEOUT
    186 //   PGTARGETSESSIONATTRS
    187 //
    188 // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
    189 //
    190 // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
    191 // usually but not always the environment variable name downcased and without the "PG" prefix.
    192 //
    193 // Important Security Notes:
    194 //
    195 // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
    196 // not set.
    197 //
    198 // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
    199 // security each sslmode provides.
    200 //
    201 // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
    202 // the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
    203 // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
    204 // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
    205 // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
    206 // TLSConfig.
    207 //
    208 // Other known differences with libpq:
    209 //
    210 // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
    211 // does not.
    212 //
    213 // In addition, ParseConfig accepts the following options:
    214 //
    215 //  min_read_buffer_size
    216 //    The minimum size of the internal read buffer. Default 8192.
    217 //  servicefile
    218 //    libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
    219 //    part of the connection string.
    220 func ParseConfig(connString string) (*Config, error) {
    221 	var parseConfigOptions ParseConfigOptions
    222 	return ParseConfigWithOptions(connString, parseConfigOptions)
    223 }
    224 
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
// get the SSL password.
//
// Settings are layered, later layers overriding earlier ones: built-in defaults, environment variables, the named
// service (when a "service" setting is present) and finally connString. Any failure is reported as a
// *parseConfigError carrying connString and a short message.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
	defaultSettings := defaultSettings()
	envSettings := parseEnvSettings()

	connStringSettings := make(map[string]string)
	if connString != "" {
		var err error
		// connString may be a database URL or a DSN
		if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
			connStringSettings, err = parseURLSettings(connString)
			if err != nil {
				return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
			}
		} else {
			connStringSettings, err = parseDSNSettings(connString)
			if err != nil {
				return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
			}
		}
	}

	// First merge is needed to discover whether a service was requested (by environment or connString).
	settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
	if service, present := settings["service"]; present {
		serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
		if err != nil {
			return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
		}

		// Re-merge with the service settings placed below connString but above the environment.
		settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
	}

	// NOTE(review): a missing min_read_buffer_size would fail to parse here; presumably defaultSettings always
	// supplies it — confirm.
	minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
	if err != nil {
		return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
	}

	config := &Config{
		createdByParseConfig: true,
		Database:             settings["database"],
		User:                 settings["user"],
		Password:             settings["password"],
		RuntimeParams:        make(map[string]string),
		BuildFrontend:        makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
	}

	// With a connect_timeout the dialer enforces that timeout; otherwise the default dialer is used and
	// ConnectTimeout stays zero.
	if connectTimeoutSetting, present := settings["connect_timeout"]; present {
		connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
		if err != nil {
			return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
		}
		config.ConnectTimeout = connectTimeout
		config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
	} else {
		defaultDialer := makeDefaultDialer()
		config.DialFunc = defaultDialer.DialContext
	}

	config.LookupFunc = makeDefaultResolver().LookupHost

	// Settings consumed by this function (or by configTLS); everything else is passed on as a runtime parameter.
	notRuntimeParams := map[string]struct{}{
		"host":                 {},
		"port":                 {},
		"database":             {},
		"user":                 {},
		"password":             {},
		"passfile":             {},
		"connect_timeout":      {},
		"sslmode":              {},
		"sslkey":               {},
		"sslcert":              {},
		"sslrootcert":          {},
		"sslpassword":          {},
		"sslsni":               {},
		"krbspn":               {},
		"krbsrvname":           {},
		"target_session_attrs": {},
		"min_read_buffer_size": {},
		"service":              {},
		"servicefile":          {},
	}

	// Adding kerberos configuration
	if _, present := settings["krbsrvname"]; present {
		config.KerberosSrvName = settings["krbsrvname"]
	}
	if _, present := settings["krbspn"]; present {
		config.KerberosSpn = settings["krbspn"]
	}

	for k, v := range settings {
		if _, present := notRuntimeParams[k]; present {
			continue
		}
		config.RuntimeParams[k] = v
	}

	// Build the ordered list of connection candidates: for every host, one entry per TLS configuration.
	fallbacks := []*FallbackConfig{}

	hosts := strings.Split(settings["host"], ",")
	ports := strings.Split(settings["port"], ",")

	for i, host := range hosts {
		// Hosts without a port of their own (fewer ports than hosts) reuse the first port.
		var portStr string
		if i < len(ports) {
			portStr = ports[i]
		} else {
			portStr = ports[0]
		}

		port, err := parsePort(portStr)
		if err != nil {
			return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
		}

		var tlsConfigs []*tls.Config

		// Ignore TLS settings if Unix domain socket like libpq
		if network, _ := NetworkAddress(host, port); network == "unix" {
			tlsConfigs = append(tlsConfigs, nil)
		} else {
			var err error
			tlsConfigs, err = configTLS(settings, host, options)
			if err != nil {
				return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
			}
		}

		for _, tlsConfig := range tlsConfigs {
			fallbacks = append(fallbacks, &FallbackConfig{
				Host:      host,
				Port:      port,
				TLSConfig: tlsConfig,
			})
		}
	}

	// The first candidate becomes the primary connection target; the remaining ones are the fallbacks.
	// strings.Split always yields at least one host and each host yields at least one TLS config, so
	// fallbacks is never empty here.
	config.Host = fallbacks[0].Host
	config.Port = fallbacks[0].Port
	config.TLSConfig = fallbacks[0].TLSConfig
	config.Fallbacks = fallbacks[1:]

	// The passfile is best effort: a read error is ignored, and it is only consulted when no password was
	// supplied. Only the primary host/port is looked up.
	passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
	if err == nil {
		if config.Password == "" {
			host := config.Host
			// Unix domain socket connections are looked up under the host name "localhost".
			if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
				host = "localhost"
			}

			config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
		}
	}

	// Map target_session_attrs onto the matching ValidateConnect implementation. Any value other than the ones
	// listed (including an absent setting, unless defaultSettings provides one — TODO confirm) is an error.
	switch tsa := settings["target_session_attrs"]; tsa {
	case "read-write":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
	case "read-only":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
	case "primary":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
	case "standby":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
	case "prefer-standby":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
	case "any":
		// do nothing
	default:
		return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
	}

	return config, nil
}
    400 
    401 func mergeSettings(settingSets ...map[string]string) map[string]string {
    402 	settings := make(map[string]string)
    403 
    404 	for _, s2 := range settingSets {
    405 		for k, v := range s2 {
    406 			settings[k] = v
    407 		}
    408 	}
    409 
    410 	return settings
    411 }
    412 
    413 func parseEnvSettings() map[string]string {
    414 	settings := make(map[string]string)
    415 
    416 	nameMap := map[string]string{
    417 		"PGHOST":               "host",
    418 		"PGPORT":               "port",
    419 		"PGDATABASE":           "database",
    420 		"PGUSER":               "user",
    421 		"PGPASSWORD":           "password",
    422 		"PGPASSFILE":           "passfile",
    423 		"PGAPPNAME":            "application_name",
    424 		"PGCONNECT_TIMEOUT":    "connect_timeout",
    425 		"PGSSLMODE":            "sslmode",
    426 		"PGSSLKEY":             "sslkey",
    427 		"PGSSLCERT":            "sslcert",
    428 		"PGSSLSNI":             "sslsni",
    429 		"PGSSLROOTCERT":        "sslrootcert",
    430 		"PGSSLPASSWORD":        "sslpassword",
    431 		"PGTARGETSESSIONATTRS": "target_session_attrs",
    432 		"PGSERVICE":            "service",
    433 		"PGSERVICEFILE":        "servicefile",
    434 	}
    435 
    436 	for envname, realname := range nameMap {
    437 		value := os.Getenv(envname)
    438 		if value != "" {
    439 			settings[realname] = value
    440 		}
    441 	}
    442 
    443 	return settings
    444 }
    445 
    446 func parseURLSettings(connString string) (map[string]string, error) {
    447 	settings := make(map[string]string)
    448 
    449 	url, err := url.Parse(connString)
    450 	if err != nil {
    451 		return nil, err
    452 	}
    453 
    454 	if url.User != nil {
    455 		settings["user"] = url.User.Username()
    456 		if password, present := url.User.Password(); present {
    457 			settings["password"] = password
    458 		}
    459 	}
    460 
    461 	// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
    462 	var hosts []string
    463 	var ports []string
    464 	for _, host := range strings.Split(url.Host, ",") {
    465 		if host == "" {
    466 			continue
    467 		}
    468 		if isIPOnly(host) {
    469 			hosts = append(hosts, strings.Trim(host, "[]"))
    470 			continue
    471 		}
    472 		h, p, err := net.SplitHostPort(host)
    473 		if err != nil {
    474 			return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
    475 		}
    476 		if h != "" {
    477 			hosts = append(hosts, h)
    478 		}
    479 		if p != "" {
    480 			ports = append(ports, p)
    481 		}
    482 	}
    483 	if len(hosts) > 0 {
    484 		settings["host"] = strings.Join(hosts, ",")
    485 	}
    486 	if len(ports) > 0 {
    487 		settings["port"] = strings.Join(ports, ",")
    488 	}
    489 
    490 	database := strings.TrimLeft(url.Path, "/")
    491 	if database != "" {
    492 		settings["database"] = database
    493 	}
    494 
    495 	nameMap := map[string]string{
    496 		"dbname": "database",
    497 	}
    498 
    499 	for k, v := range url.Query() {
    500 		if k2, present := nameMap[k]; present {
    501 			k = k2
    502 		}
    503 
    504 		settings[k] = v[0]
    505 	}
    506 
    507 	return settings, nil
    508 }
    509 
    510 func isIPOnly(host string) bool {
    511 	return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
    512 }
    513 
    514 var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
    515 
    516 func parseDSNSettings(s string) (map[string]string, error) {
    517 	settings := make(map[string]string)
    518 
    519 	nameMap := map[string]string{
    520 		"dbname": "database",
    521 	}
    522 
    523 	for len(s) > 0 {
    524 		var key, val string
    525 		eqIdx := strings.IndexRune(s, '=')
    526 		if eqIdx < 0 {
    527 			return nil, errors.New("invalid dsn")
    528 		}
    529 
    530 		key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
    531 		s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
    532 		if len(s) == 0 {
    533 		} else if s[0] != '\'' {
    534 			end := 0
    535 			for ; end < len(s); end++ {
    536 				if asciiSpace[s[end]] == 1 {
    537 					break
    538 				}
    539 				if s[end] == '\\' {
    540 					end++
    541 					if end == len(s) {
    542 						return nil, errors.New("invalid backslash")
    543 					}
    544 				}
    545 			}
    546 			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
    547 			if end == len(s) {
    548 				s = ""
    549 			} else {
    550 				s = s[end+1:]
    551 			}
    552 		} else { // quoted string
    553 			s = s[1:]
    554 			end := 0
    555 			for ; end < len(s); end++ {
    556 				if s[end] == '\'' {
    557 					break
    558 				}
    559 				if s[end] == '\\' {
    560 					end++
    561 				}
    562 			}
    563 			if end == len(s) {
    564 				return nil, errors.New("unterminated quoted string in connection info string")
    565 			}
    566 			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
    567 			if end == len(s) {
    568 				s = ""
    569 			} else {
    570 				s = s[end+1:]
    571 			}
    572 		}
    573 
    574 		if k, ok := nameMap[key]; ok {
    575 			key = k
    576 		}
    577 
    578 		if key == "" {
    579 			return nil, errors.New("invalid dsn")
    580 		}
    581 
    582 		settings[key] = val
    583 	}
    584 
    585 	return settings, nil
    586 }
    587 
    588 func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
    589 	servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
    590 	if err != nil {
    591 		return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
    592 	}
    593 
    594 	service, err := servicefile.GetService(serviceName)
    595 	if err != nil {
    596 		return nil, fmt.Errorf("unable to find service: %v", serviceName)
    597 	}
    598 
    599 	nameMap := map[string]string{
    600 		"dbname": "database",
    601 	}
    602 
    603 	settings := make(map[string]string, len(service.Settings))
    604 	for k, v := range service.Settings {
    605 		if k2, present := nameMap[k]; present {
    606 			k = k2
    607 		}
    608 		settings[k] = v
    609 	}
    610 
    611 	return settings, nil
    612 }
    613 
    614 // configTLS uses libpq's TLS parameters to construct  []*tls.Config. It is
    615 // necessary to allow returning multiple TLS configs as sslmode "allow" and
    616 // "prefer" allow fallback.
    617 func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
    618 	host := thisHost
    619 	sslmode := settings["sslmode"]
    620 	sslrootcert := settings["sslrootcert"]
    621 	sslcert := settings["sslcert"]
    622 	sslkey := settings["sslkey"]
    623 	sslpassword := settings["sslpassword"]
    624 	sslsni := settings["sslsni"]
    625 
    626 	// Match libpq default behavior
    627 	if sslmode == "" {
    628 		sslmode = "prefer"
    629 	}
    630 	if sslsni == "" {
    631 		sslsni = "1"
    632 	}
    633 
    634 	tlsConfig := &tls.Config{}
    635 
    636 	switch sslmode {
    637 	case "disable":
    638 		return []*tls.Config{nil}, nil
    639 	case "allow", "prefer":
    640 		tlsConfig.InsecureSkipVerify = true
    641 	case "require":
    642 		// According to PostgreSQL documentation, if a root CA file exists,
    643 		// the behavior of sslmode=require should be the same as that of verify-ca
    644 		//
    645 		// See https://www.postgresql.org/docs/12/libpq-ssl.html
    646 		if sslrootcert != "" {
    647 			goto nextCase
    648 		}
    649 		tlsConfig.InsecureSkipVerify = true
    650 		break
    651 	nextCase:
    652 		fallthrough
    653 	case "verify-ca":
    654 		// Don't perform the default certificate verification because it
    655 		// will verify the hostname. Instead, verify the server's
    656 		// certificate chain ourselves in VerifyPeerCertificate and
    657 		// ignore the server name. This emulates libpq's verify-ca
    658 		// behavior.
    659 		//
    660 		// See https://github.com/golang/go/issues/21971#issuecomment-332693931
    661 		// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
    662 		// for more info.
    663 		tlsConfig.InsecureSkipVerify = true
    664 		tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
    665 			certs := make([]*x509.Certificate, len(certificates))
    666 			for i, asn1Data := range certificates {
    667 				cert, err := x509.ParseCertificate(asn1Data)
    668 				if err != nil {
    669 					return errors.New("failed to parse certificate from server: " + err.Error())
    670 				}
    671 				certs[i] = cert
    672 			}
    673 
    674 			// Leave DNSName empty to skip hostname verification.
    675 			opts := x509.VerifyOptions{
    676 				Roots:         tlsConfig.RootCAs,
    677 				Intermediates: x509.NewCertPool(),
    678 			}
    679 			// Skip the first cert because it's the leaf. All others
    680 			// are intermediates.
    681 			for _, cert := range certs[1:] {
    682 				opts.Intermediates.AddCert(cert)
    683 			}
    684 			_, err := certs[0].Verify(opts)
    685 			return err
    686 		}
    687 	case "verify-full":
    688 		tlsConfig.ServerName = host
    689 	default:
    690 		return nil, errors.New("sslmode is invalid")
    691 	}
    692 
    693 	if sslrootcert != "" {
    694 		caCertPool := x509.NewCertPool()
    695 
    696 		caPath := sslrootcert
    697 		caCert, err := ioutil.ReadFile(caPath)
    698 		if err != nil {
    699 			return nil, fmt.Errorf("unable to read CA file: %w", err)
    700 		}
    701 
    702 		if !caCertPool.AppendCertsFromPEM(caCert) {
    703 			return nil, errors.New("unable to add CA to cert pool")
    704 		}
    705 
    706 		tlsConfig.RootCAs = caCertPool
    707 		tlsConfig.ClientCAs = caCertPool
    708 	}
    709 
    710 	if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
    711 		return nil, errors.New(`both "sslcert" and "sslkey" are required`)
    712 	}
    713 
    714 	if sslcert != "" && sslkey != "" {
    715 		buf, err := ioutil.ReadFile(sslkey)
    716 		if err != nil {
    717 			return nil, fmt.Errorf("unable to read sslkey: %w", err)
    718 		}
    719 		block, _ := pem.Decode(buf)
    720 		var pemKey []byte
    721 		var decryptedKey []byte
    722 		var decryptedError error
    723 		// If PEM is encrypted, attempt to decrypt using pass phrase
    724 		if x509.IsEncryptedPEMBlock(block) {
    725 			// Attempt decryption with pass phrase
    726 			// NOTE: only supports RSA (PKCS#1)
    727 			if sslpassword != "" {
    728 				decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
    729 			}
    730 			//if sslpassword not provided or has decryption error when use it
    731 			//try to find sslpassword with callback function
    732 			if sslpassword == "" || decryptedError != nil {
    733 				if parseConfigOptions.GetSSLPassword != nil {
    734 					sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
    735 				}
    736 				if sslpassword == "" {
    737 					return nil, fmt.Errorf("unable to find sslpassword")
    738 				}
    739 			}
    740 			decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
    741 			// Should we also provide warning for PKCS#1 needed?
    742 			if decryptedError != nil {
    743 				return nil, fmt.Errorf("unable to decrypt key: %w", err)
    744 			}
    745 
    746 			pemBytes := pem.Block{
    747 				Type:  "RSA PRIVATE KEY",
    748 				Bytes: decryptedKey,
    749 			}
    750 			pemKey = pem.EncodeToMemory(&pemBytes)
    751 		} else {
    752 			pemKey = pem.EncodeToMemory(block)
    753 		}
    754 		certfile, err := ioutil.ReadFile(sslcert)
    755 		if err != nil {
    756 			return nil, fmt.Errorf("unable to read cert: %w", err)
    757 		}
    758 		cert, err := tls.X509KeyPair(certfile, pemKey)
    759 		if err != nil {
    760 			return nil, fmt.Errorf("unable to load cert: %w", err)
    761 		}
    762 		tlsConfig.Certificates = []tls.Certificate{cert}
    763 	}
    764 
    765 	// Set Server Name Indication (SNI), if enabled by connection parameters.
    766 	// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
    767 	// or IPv6).
    768 	if sslsni == "1" && net.ParseIP(host) == nil {
    769 		tlsConfig.ServerName = host
    770 	}
    771 
    772 	switch sslmode {
    773 	case "allow":
    774 		return []*tls.Config{nil, tlsConfig}, nil
    775 	case "prefer":
    776 		return []*tls.Config{tlsConfig, nil}, nil
    777 	case "require", "verify-ca", "verify-full":
    778 		return []*tls.Config{tlsConfig}, nil
    779 	default:
    780 		panic("BUG: bad sslmode should already have been caught")
    781 	}
    782 }
    783 
    784 func parsePort(s string) (uint16, error) {
    785 	port, err := strconv.ParseUint(s, 10, 16)
    786 	if err != nil {
    787 		return 0, err
    788 	}
    789 	if port < 1 || port > math.MaxUint16 {
    790 		return 0, errors.New("outside range")
    791 	}
    792 	return uint16(port), nil
    793 }
    794 
    795 func makeDefaultDialer() *net.Dialer {
    796 	return &net.Dialer{KeepAlive: 5 * time.Minute}
    797 }
    798 
    799 func makeDefaultResolver() *net.Resolver {
    800 	return net.DefaultResolver
    801 }
    802 
    803 func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
    804 	return func(r io.Reader, w io.Writer) Frontend {
    805 		cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
    806 		if err != nil {
    807 			panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
    808 		}
    809 		frontend := pgproto3.NewFrontend(cr, w)
    810 
    811 		return frontend
    812 	}
    813 }
    814 
    815 func parseConnectTimeoutSetting(s string) (time.Duration, error) {
    816 	timeout, err := strconv.ParseInt(s, 10, 64)
    817 	if err != nil {
    818 		return 0, err
    819 	}
    820 	if timeout < 0 {
    821 		return 0, errors.New("negative timeout")
    822 	}
    823 	return time.Duration(timeout) * time.Second, nil
    824 }
    825 
    826 func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
    827 	d := makeDefaultDialer()
    828 	d.Timeout = timeout
    829 	return d.DialContext
    830 }
    831 
    832 // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
    833 // target_session_attrs=read-write.
    834 func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
    835 	result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
    836 	if result.Err != nil {
    837 		return result.Err
    838 	}
    839 
    840 	if string(result.Rows[0][0]) == "on" {
    841 		return errors.New("read only connection")
    842 	}
    843 
    844 	return nil
    845 }
    846 
    847 // ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
    848 // target_session_attrs=read-only.
    849 func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
    850 	result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
    851 	if result.Err != nil {
    852 		return result.Err
    853 	}
    854 
    855 	if string(result.Rows[0][0]) != "on" {
    856 		return errors.New("connection is not read only")
    857 	}
    858 
    859 	return nil
    860 }
    861 
    862 // ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
    863 // target_session_attrs=standby.
    864 func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
    865 	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
    866 	if result.Err != nil {
    867 		return result.Err
    868 	}
    869 
    870 	if string(result.Rows[0][0]) != "t" {
    871 		return errors.New("server is not in hot standby mode")
    872 	}
    873 
    874 	return nil
    875 }
    876 
    877 // ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
    878 // target_session_attrs=primary.
    879 func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
    880 	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
    881 	if result.Err != nil {
    882 		return result.Err
    883 	}
    884 
    885 	if string(result.Rows[0][0]) == "t" {
    886 		return errors.New("server is in standby mode")
    887 	}
    888 
    889 	return nil
    890 }
    891 
    892 // ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
    893 // target_session_attrs=prefer-standby.
    894 func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
    895 	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
    896 	if result.Err != nil {
    897 		return result.Err
    898 	}
    899 
    900 	if string(result.Rows[0][0]) != "t" {
    901 		return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
    902 	}
    903 
    904 	return nil
    905 }