gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

auth_scram.go (7712B)


      1 // SCRAM-SHA-256 authentication
      2 //
      3 // Resources:
      4 //   https://tools.ietf.org/html/rfc5802
      5 //   https://tools.ietf.org/html/rfc8265
      6 //   https://www.postgresql.org/docs/current/sasl-authentication.html
      7 //
      8 // Inspiration drawn from other implementations:
      9 //   https://github.com/lib/pq/pull/608
     10 //   https://github.com/lib/pq/pull/788
     11 //   https://github.com/lib/pq/pull/833
     12 
     13 package pgconn
     14 
     15 import (
     16 	"bytes"
     17 	"crypto/hmac"
     18 	"crypto/rand"
     19 	"crypto/sha256"
     20 	"encoding/base64"
     21 	"errors"
     22 	"fmt"
     23 	"strconv"
     24 
     25 	"github.com/jackc/pgproto3/v2"
     26 	"golang.org/x/crypto/pbkdf2"
     27 	"golang.org/x/text/secure/precis"
     28 )
     29 
     // clientNonceLen is the number of random bytes drawn from crypto/rand for
// the SCRAM client nonce; newScramClient base64-encodes them (unpadded
// standard alphabet) before they are sent.
const clientNonceLen = 18
     31 
     32 // Perform SCRAM authentication.
     33 func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
     34 	sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
     35 	if err != nil {
     36 		return err
     37 	}
     38 
     39 	// Send client-first-message in a SASLInitialResponse
     40 	saslInitialResponse := &pgproto3.SASLInitialResponse{
     41 		AuthMechanism: "SCRAM-SHA-256",
     42 		Data:          sc.clientFirstMessage(),
     43 	}
     44 	_, err = c.conn.Write(saslInitialResponse.Encode(nil))
     45 	if err != nil {
     46 		return err
     47 	}
     48 
     49 	// Receive server-first-message payload in a AuthenticationSASLContinue.
     50 	saslContinue, err := c.rxSASLContinue()
     51 	if err != nil {
     52 		return err
     53 	}
     54 	err = sc.recvServerFirstMessage(saslContinue.Data)
     55 	if err != nil {
     56 		return err
     57 	}
     58 
     59 	// Send client-final-message in a SASLResponse
     60 	saslResponse := &pgproto3.SASLResponse{
     61 		Data: []byte(sc.clientFinalMessage()),
     62 	}
     63 	_, err = c.conn.Write(saslResponse.Encode(nil))
     64 	if err != nil {
     65 		return err
     66 	}
     67 
     68 	// Receive server-final-message payload in a AuthenticationSASLFinal.
     69 	saslFinal, err := c.rxSASLFinal()
     70 	if err != nil {
     71 		return err
     72 	}
     73 	return sc.recvServerFinalMessage(saslFinal.Data)
     74 }
     75 
     76 func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
     77 	msg, err := c.receiveMessage()
     78 	if err != nil {
     79 		return nil, err
     80 	}
     81 	switch m := msg.(type) {
     82 	case *pgproto3.AuthenticationSASLContinue:
     83 		return m, nil
     84 	case *pgproto3.ErrorResponse:
     85 		return nil, ErrorResponseToPgError(m)
     86 	}
     87 
     88 	return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
     89 }
     90 
     91 func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
     92 	msg, err := c.receiveMessage()
     93 	if err != nil {
     94 		return nil, err
     95 	}
     96 	switch m := msg.(type) {
     97 	case *pgproto3.AuthenticationSASLFinal:
     98 		return m, nil
     99 	case *pgproto3.ErrorResponse:
    100 		return nil, ErrorResponseToPgError(m)
    101 	}
    102 
    103 	return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
    104 }
    105 
    // scramClient holds the state of one client-side SCRAM-SHA-256 exchange.
// Its fields are filled in progressively as each protocol message is built
// or parsed; field order is significant only for unkeyed literals.
type scramClient struct {
	// Populated by newScramClient.
	serverAuthMechanisms  []string
	password, clientNonce []byte

	// Populated by clientFirstMessage: the message without its GS2 header.
	clientFirstMessageBare []byte

	// Populated by recvServerFirstMessage.
	serverFirstMessage, clientAndServerNonce, salt []byte
	iterations                                     int

	// Populated by clientFinalMessage; read back by recvServerFinalMessage.
	saltedPassword, authMessage []byte
}
    121 
    122 func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
    123 	sc := &scramClient{
    124 		serverAuthMechanisms: serverAuthMechanisms,
    125 	}
    126 
    127 	// Ensure server supports SCRAM-SHA-256
    128 	hasScramSHA256 := false
    129 	for _, mech := range sc.serverAuthMechanisms {
    130 		if mech == "SCRAM-SHA-256" {
    131 			hasScramSHA256 = true
    132 			break
    133 		}
    134 	}
    135 	if !hasScramSHA256 {
    136 		return nil, errors.New("server does not support SCRAM-SHA-256")
    137 	}
    138 
    139 	// precis.OpaqueString is equivalent to SASLprep for password.
    140 	var err error
    141 	sc.password, err = precis.OpaqueString.Bytes([]byte(password))
    142 	if err != nil {
    143 		// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
    144 		sc.password = []byte(password)
    145 	}
    146 
    147 	buf := make([]byte, clientNonceLen)
    148 	_, err = rand.Read(buf)
    149 	if err != nil {
    150 		return nil, err
    151 	}
    152 	sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
    153 	base64.RawStdEncoding.Encode(sc.clientNonce, buf)
    154 
    155 	return sc, nil
    156 }
    157 
    158 func (sc *scramClient) clientFirstMessage() []byte {
    159 	sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
    160 	return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
    161 }
    162 
    163 func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
    164 	sc.serverFirstMessage = serverFirstMessage
    165 	buf := serverFirstMessage
    166 	if !bytes.HasPrefix(buf, []byte("r=")) {
    167 		return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
    168 	}
    169 	buf = buf[2:]
    170 
    171 	idx := bytes.IndexByte(buf, ',')
    172 	if idx == -1 {
    173 		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
    174 	}
    175 	sc.clientAndServerNonce = buf[:idx]
    176 	buf = buf[idx+1:]
    177 
    178 	if !bytes.HasPrefix(buf, []byte("s=")) {
    179 		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
    180 	}
    181 	buf = buf[2:]
    182 
    183 	idx = bytes.IndexByte(buf, ',')
    184 	if idx == -1 {
    185 		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
    186 	}
    187 	saltStr := buf[:idx]
    188 	buf = buf[idx+1:]
    189 
    190 	if !bytes.HasPrefix(buf, []byte("i=")) {
    191 		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
    192 	}
    193 	buf = buf[2:]
    194 	iterationsStr := buf
    195 
    196 	var err error
    197 	sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
    198 	if err != nil {
    199 		return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
    200 	}
    201 
    202 	sc.iterations, err = strconv.Atoi(string(iterationsStr))
    203 	if err != nil || sc.iterations <= 0 {
    204 		return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
    205 	}
    206 
    207 	if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
    208 		return errors.New("invalid SCRAM nonce: did not start with client nonce")
    209 	}
    210 
    211 	if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
    212 		return errors.New("invalid SCRAM nonce: did not include server nonce")
    213 	}
    214 
    215 	return nil
    216 }
    217 
    218 func (sc *scramClient) clientFinalMessage() string {
    219 	clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
    220 
    221 	sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
    222 	sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
    223 
    224 	clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
    225 
    226 	return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
    227 }
    228 
    229 func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
    230 	if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
    231 		return errors.New("invalid SCRAM server-final-message received from server")
    232 	}
    233 
    234 	serverSignature := serverFinalMessage[2:]
    235 
    236 	if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
    237 		return errors.New("invalid SCRAM ServerSignature received from server")
    238 	}
    239 
    240 	return nil
    241 }
    242 
    // computeHMAC returns the raw 32-byte HMAC-SHA-256 of msg under key.
func computeHMAC(key, msg []byte) []byte {
	h := hmac.New(sha256.New, key)
	h.Write(msg) // hash.Hash.Write is documented never to return an error.
	return h.Sum(make([]byte, 0, sha256.Size))
}
    248 
    249 func computeClientProof(saltedPassword, authMessage []byte) []byte {
    250 	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
    251 	storedKey := sha256.Sum256(clientKey)
    252 	clientSignature := computeHMAC(storedKey[:], authMessage)
    253 
    254 	clientProof := make([]byte, len(clientSignature))
    255 	for i := 0; i < len(clientSignature); i++ {
    256 		clientProof[i] = clientKey[i] ^ clientSignature[i]
    257 	}
    258 
    259 	buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
    260 	base64.StdEncoding.Encode(buf, clientProof)
    261 	return buf
    262 }
    263 
    264 func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
    265 	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
    266 	serverSignature := computeHMAC(serverKey, authMessage)
    267 	buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
    268 	base64.StdEncoding.Encode(buf, serverSignature)
    269 	return buf
    270 }