auth_scram.go (7774B)
1 // SCRAM-SHA-256 authentication 2 // 3 // Resources: 4 // https://tools.ietf.org/html/rfc5802 5 // https://tools.ietf.org/html/rfc8265 6 // https://www.postgresql.org/docs/current/sasl-authentication.html 7 // 8 // Inspiration drawn from other implementations: 9 // https://github.com/lib/pq/pull/608 10 // https://github.com/lib/pq/pull/788 11 // https://github.com/lib/pq/pull/833 12 13 package pgconn 14 15 import ( 16 "bytes" 17 "crypto/hmac" 18 "crypto/rand" 19 "crypto/sha256" 20 "encoding/base64" 21 "errors" 22 "fmt" 23 "strconv" 24 25 "github.com/jackc/pgx/v5/pgproto3" 26 "golang.org/x/crypto/pbkdf2" 27 "golang.org/x/text/secure/precis" 28 ) 29 30 const clientNonceLen = 18 31 32 // Perform SCRAM authentication. 33 func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { 34 sc, err := newScramClient(serverAuthMechanisms, c.config.Password) 35 if err != nil { 36 return err 37 } 38 39 // Send client-first-message in a SASLInitialResponse 40 saslInitialResponse := &pgproto3.SASLInitialResponse{ 41 AuthMechanism: "SCRAM-SHA-256", 42 Data: sc.clientFirstMessage(), 43 } 44 c.frontend.Send(saslInitialResponse) 45 err = c.flushWithPotentialWriteReadDeadlock() 46 if err != nil { 47 return err 48 } 49 50 // Receive server-first-message payload in a AuthenticationSASLContinue. 51 saslContinue, err := c.rxSASLContinue() 52 if err != nil { 53 return err 54 } 55 err = sc.recvServerFirstMessage(saslContinue.Data) 56 if err != nil { 57 return err 58 } 59 60 // Send client-final-message in a SASLResponse 61 saslResponse := &pgproto3.SASLResponse{ 62 Data: []byte(sc.clientFinalMessage()), 63 } 64 c.frontend.Send(saslResponse) 65 err = c.flushWithPotentialWriteReadDeadlock() 66 if err != nil { 67 return err 68 } 69 70 // Receive server-final-message payload in a AuthenticationSASLFinal. 71 saslFinal, err := c.rxSASLFinal() 72 if err != nil { 73 return err 74 } 75 return sc.recvServerFinalMessage(saslFinal.Data) 76 } 77 78 func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { 79 msg, err := c.receiveMessage() 80 if err != nil { 81 return nil, err 82 } 83 switch m := msg.(type) { 84 case *pgproto3.AuthenticationSASLContinue: 85 return m, nil 86 case *pgproto3.ErrorResponse: 87 return nil, ErrorResponseToPgError(m) 88 } 89 90 return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg) 91 } 92 93 func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { 94 msg, err := c.receiveMessage() 95 if err != nil { 96 return nil, err 97 } 98 switch m := msg.(type) { 99 case *pgproto3.AuthenticationSASLFinal: 100 return m, nil 101 case *pgproto3.ErrorResponse: 102 return nil, ErrorResponseToPgError(m) 103 } 104 105 return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg) 106 } 107 108 type scramClient struct { 109 serverAuthMechanisms []string 110 password []byte 111 clientNonce []byte 112 113 clientFirstMessageBare []byte 114 115 serverFirstMessage []byte 116 clientAndServerNonce []byte 117 salt []byte 118 iterations int 119 120 saltedPassword []byte 121 authMessage []byte 122 } 123 124 func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { 125 sc := &scramClient{ 126 serverAuthMechanisms: serverAuthMechanisms, 127 } 128 129 // Ensure server supports SCRAM-SHA-256 130 hasScramSHA256 := false 131 for _, mech := range sc.serverAuthMechanisms { 132 if mech == "SCRAM-SHA-256" { 133 hasScramSHA256 = true 134 break 135 } 136 } 137 if !hasScramSHA256 { 138 return nil, errors.New("server does not support SCRAM-SHA-256") 139 } 140 141 // precis.OpaqueString is equivalent to SASLprep for password. 142 var err error 143 sc.password, err = precis.OpaqueString.Bytes([]byte(password)) 144 if err != nil { 145 // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. 146 sc.password = []byte(password) 147 } 148 149 buf := make([]byte, clientNonceLen) 150 _, err = rand.Read(buf) 151 if err != nil { 152 return nil, err 153 } 154 sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) 155 base64.RawStdEncoding.Encode(sc.clientNonce, buf) 156 157 return sc, nil 158 } 159 160 func (sc *scramClient) clientFirstMessage() []byte { 161 sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) 162 return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) 163 } 164 165 func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { 166 sc.serverFirstMessage = serverFirstMessage 167 buf := serverFirstMessage 168 if !bytes.HasPrefix(buf, []byte("r=")) { 169 return errors.New("invalid SCRAM server-first-message received from server: did not include r=") 170 } 171 buf = buf[2:] 172 173 idx := bytes.IndexByte(buf, ',') 174 if idx == -1 { 175 return errors.New("invalid SCRAM server-first-message received from server: did not include s=") 176 } 177 sc.clientAndServerNonce = buf[:idx] 178 buf = buf[idx+1:] 179 180 if !bytes.HasPrefix(buf, []byte("s=")) { 181 return errors.New("invalid SCRAM server-first-message received from server: did not include s=") 182 } 183 buf = buf[2:] 184 185 idx = bytes.IndexByte(buf, ',') 186 if idx == -1 { 187 return errors.New("invalid SCRAM server-first-message received from server: did not include i=") 188 } 189 saltStr := buf[:idx] 190 buf = buf[idx+1:] 191 192 if !bytes.HasPrefix(buf, []byte("i=")) { 193 return errors.New("invalid SCRAM server-first-message received from server: did not include i=") 194 } 195 buf = buf[2:] 196 iterationsStr := buf 197 198 var err error 199 sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) 200 if err != nil { 201 return fmt.Errorf("invalid SCRAM salt received from server: %w", err) 202 } 203 204 sc.iterations, err = strconv.Atoi(string(iterationsStr)) 205 if err != nil || sc.iterations <= 0 { 206 return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) 207 } 208 209 if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { 210 return errors.New("invalid SCRAM nonce: did not start with client nonce") 211 } 212 213 if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { 214 return errors.New("invalid SCRAM nonce: did not include server nonce") 215 } 216 217 return nil 218 } 219 220 func (sc *scramClient) clientFinalMessage() string { 221 clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) 222 223 sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) 224 sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) 225 226 clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) 227 228 return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) 229 } 230 231 func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { 232 if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { 233 return errors.New("invalid SCRAM server-final-message received from server") 234 } 235 236 serverSignature := serverFinalMessage[2:] 237 238 if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { 239 return errors.New("invalid SCRAM ServerSignature received from server") 240 } 241 242 return nil 243 } 244 245 func computeHMAC(key, msg []byte) []byte { 246 mac := hmac.New(sha256.New, key) 247 mac.Write(msg) 248 return mac.Sum(nil) 249 } 250 251 func computeClientProof(saltedPassword, authMessage []byte) []byte { 252 clientKey := computeHMAC(saltedPassword, []byte("Client Key")) 253 storedKey := sha256.Sum256(clientKey) 254 clientSignature := computeHMAC(storedKey[:], authMessage) 255 256 clientProof := make([]byte, len(clientSignature)) 257 for i := 0; i < len(clientSignature); i++ { 258 clientProof[i] = clientKey[i] ^ clientSignature[i] 259 } 260 261 buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) 262 base64.StdEncoding.Encode(buf, clientProof) 263 return buf 264 } 265 266 func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { 267 serverKey := computeHMAC(saltedPassword, []byte("Server Key")) 268 serverSignature := computeHMAC(serverKey, authMessage) 269 buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) 270 base64.StdEncoding.Encode(buf, serverSignature) 271 return buf 272 }