gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

auth_scram.go (7774B)


      1 // SCRAM-SHA-256 authentication
      2 //
      3 // Resources:
      4 //   https://tools.ietf.org/html/rfc5802
      5 //   https://tools.ietf.org/html/rfc8265
      6 //   https://www.postgresql.org/docs/current/sasl-authentication.html
      7 //
      8 // Inspiration drawn from other implementations:
      9 //   https://github.com/lib/pq/pull/608
     10 //   https://github.com/lib/pq/pull/788
     11 //   https://github.com/lib/pq/pull/833
     12 
     13 package pgconn
     14 
     15 import (
     16 	"bytes"
     17 	"crypto/hmac"
     18 	"crypto/rand"
     19 	"crypto/sha256"
     20 	"encoding/base64"
     21 	"errors"
     22 	"fmt"
     23 	"strconv"
     24 
     25 	"github.com/jackc/pgx/v5/pgproto3"
     26 	"golang.org/x/crypto/pbkdf2"
     27 	"golang.org/x/text/secure/precis"
     28 )
     29 
// clientNonceLen is the number of random bytes drawn for the SCRAM client
// nonce. newScramClient encodes them with unpadded standard base64, so the
// nonce placed on the wire is 24 characters long.
const clientNonceLen = 18
     31 
     32 // Perform SCRAM authentication.
     33 func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
     34 	sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
     35 	if err != nil {
     36 		return err
     37 	}
     38 
     39 	// Send client-first-message in a SASLInitialResponse
     40 	saslInitialResponse := &pgproto3.SASLInitialResponse{
     41 		AuthMechanism: "SCRAM-SHA-256",
     42 		Data:          sc.clientFirstMessage(),
     43 	}
     44 	c.frontend.Send(saslInitialResponse)
     45 	err = c.flushWithPotentialWriteReadDeadlock()
     46 	if err != nil {
     47 		return err
     48 	}
     49 
     50 	// Receive server-first-message payload in a AuthenticationSASLContinue.
     51 	saslContinue, err := c.rxSASLContinue()
     52 	if err != nil {
     53 		return err
     54 	}
     55 	err = sc.recvServerFirstMessage(saslContinue.Data)
     56 	if err != nil {
     57 		return err
     58 	}
     59 
     60 	// Send client-final-message in a SASLResponse
     61 	saslResponse := &pgproto3.SASLResponse{
     62 		Data: []byte(sc.clientFinalMessage()),
     63 	}
     64 	c.frontend.Send(saslResponse)
     65 	err = c.flushWithPotentialWriteReadDeadlock()
     66 	if err != nil {
     67 		return err
     68 	}
     69 
     70 	// Receive server-final-message payload in a AuthenticationSASLFinal.
     71 	saslFinal, err := c.rxSASLFinal()
     72 	if err != nil {
     73 		return err
     74 	}
     75 	return sc.recvServerFinalMessage(saslFinal.Data)
     76 }
     77 
     78 func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
     79 	msg, err := c.receiveMessage()
     80 	if err != nil {
     81 		return nil, err
     82 	}
     83 	switch m := msg.(type) {
     84 	case *pgproto3.AuthenticationSASLContinue:
     85 		return m, nil
     86 	case *pgproto3.ErrorResponse:
     87 		return nil, ErrorResponseToPgError(m)
     88 	}
     89 
     90 	return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
     91 }
     92 
     93 func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
     94 	msg, err := c.receiveMessage()
     95 	if err != nil {
     96 		return nil, err
     97 	}
     98 	switch m := msg.(type) {
     99 	case *pgproto3.AuthenticationSASLFinal:
    100 		return m, nil
    101 	case *pgproto3.ErrorResponse:
    102 		return nil, ErrorResponseToPgError(m)
    103 	}
    104 
    105 	return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
    106 }
    107 
// scramClient carries the client-side state of a single SCRAM-SHA-256
// exchange. Each group of fields below is populated by a different step.
type scramClient struct {
	// Set by newScramClient.
	serverAuthMechanisms []string // SASL mechanisms advertised by the server
	password             []byte   // SASLprep-normalized password, or the raw password if normalization failed
	clientNonce          []byte   // unpadded standard-base64 encoding of clientNonceLen random bytes

	// Set by clientFirstMessage: the client-first-message without its "n,," GS2 header.
	clientFirstMessageBare []byte

	// Set by recvServerFirstMessage from the server-first-message.
	serverFirstMessage   []byte // the server-first-message contents
	clientAndServerNonce []byte // value of r=: client nonce followed by server nonce
	salt                 []byte // base64-decoded value of s=
	iterations           int    // value of i=, used as the PBKDF2 iteration count

	// Set by clientFinalMessage; read again by recvServerFinalMessage.
	saltedPassword []byte // PBKDF2-HMAC-SHA-256(password, salt, iterations), 32 bytes
	authMessage    []byte // client-first-message-bare "," server-first-message "," client-final-message-without-proof
}
    123 
    124 func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
    125 	sc := &scramClient{
    126 		serverAuthMechanisms: serverAuthMechanisms,
    127 	}
    128 
    129 	// Ensure server supports SCRAM-SHA-256
    130 	hasScramSHA256 := false
    131 	for _, mech := range sc.serverAuthMechanisms {
    132 		if mech == "SCRAM-SHA-256" {
    133 			hasScramSHA256 = true
    134 			break
    135 		}
    136 	}
    137 	if !hasScramSHA256 {
    138 		return nil, errors.New("server does not support SCRAM-SHA-256")
    139 	}
    140 
    141 	// precis.OpaqueString is equivalent to SASLprep for password.
    142 	var err error
    143 	sc.password, err = precis.OpaqueString.Bytes([]byte(password))
    144 	if err != nil {
    145 		// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
    146 		sc.password = []byte(password)
    147 	}
    148 
    149 	buf := make([]byte, clientNonceLen)
    150 	_, err = rand.Read(buf)
    151 	if err != nil {
    152 		return nil, err
    153 	}
    154 	sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
    155 	base64.RawStdEncoding.Encode(sc.clientNonce, buf)
    156 
    157 	return sc, nil
    158 }
    159 
    160 func (sc *scramClient) clientFirstMessage() []byte {
    161 	sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
    162 	return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
    163 }
    164 
    165 func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
    166 	sc.serverFirstMessage = serverFirstMessage
    167 	buf := serverFirstMessage
    168 	if !bytes.HasPrefix(buf, []byte("r=")) {
    169 		return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
    170 	}
    171 	buf = buf[2:]
    172 
    173 	idx := bytes.IndexByte(buf, ',')
    174 	if idx == -1 {
    175 		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
    176 	}
    177 	sc.clientAndServerNonce = buf[:idx]
    178 	buf = buf[idx+1:]
    179 
    180 	if !bytes.HasPrefix(buf, []byte("s=")) {
    181 		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
    182 	}
    183 	buf = buf[2:]
    184 
    185 	idx = bytes.IndexByte(buf, ',')
    186 	if idx == -1 {
    187 		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
    188 	}
    189 	saltStr := buf[:idx]
    190 	buf = buf[idx+1:]
    191 
    192 	if !bytes.HasPrefix(buf, []byte("i=")) {
    193 		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
    194 	}
    195 	buf = buf[2:]
    196 	iterationsStr := buf
    197 
    198 	var err error
    199 	sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
    200 	if err != nil {
    201 		return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
    202 	}
    203 
    204 	sc.iterations, err = strconv.Atoi(string(iterationsStr))
    205 	if err != nil || sc.iterations <= 0 {
    206 		return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
    207 	}
    208 
    209 	if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
    210 		return errors.New("invalid SCRAM nonce: did not start with client nonce")
    211 	}
    212 
    213 	if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
    214 		return errors.New("invalid SCRAM nonce: did not include server nonce")
    215 	}
    216 
    217 	return nil
    218 }
    219 
    220 func (sc *scramClient) clientFinalMessage() string {
    221 	clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
    222 
    223 	sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
    224 	sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
    225 
    226 	clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
    227 
    228 	return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
    229 }
    230 
    231 func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
    232 	if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
    233 		return errors.New("invalid SCRAM server-final-message received from server")
    234 	}
    235 
    236 	serverSignature := serverFinalMessage[2:]
    237 
    238 	if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
    239 		return errors.New("invalid SCRAM ServerSignature received from server")
    240 	}
    241 
    242 	return nil
    243 }
    244 
    245 func computeHMAC(key, msg []byte) []byte {
    246 	mac := hmac.New(sha256.New, key)
    247 	mac.Write(msg)
    248 	return mac.Sum(nil)
    249 }
    250 
    251 func computeClientProof(saltedPassword, authMessage []byte) []byte {
    252 	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
    253 	storedKey := sha256.Sum256(clientKey)
    254 	clientSignature := computeHMAC(storedKey[:], authMessage)
    255 
    256 	clientProof := make([]byte, len(clientSignature))
    257 	for i := 0; i < len(clientSignature); i++ {
    258 		clientProof[i] = clientKey[i] ^ clientSignature[i]
    259 	}
    260 
    261 	buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
    262 	base64.StdEncoding.Encode(buf, clientProof)
    263 	return buf
    264 }
    265 
    266 func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
    267 	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
    268 	serverSignature := computeHMAC(serverKey, authMessage)
    269 	buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
    270 	base64.StdEncoding.Encode(buf, serverSignature)
    271 	return buf
    272 }