auth_scram.go (7712B)
1 // SCRAM-SHA-256 authentication 2 // 3 // Resources: 4 // https://tools.ietf.org/html/rfc5802 5 // https://tools.ietf.org/html/rfc8265 6 // https://www.postgresql.org/docs/current/sasl-authentication.html 7 // 8 // Inspiration drawn from other implementations: 9 // https://github.com/lib/pq/pull/608 10 // https://github.com/lib/pq/pull/788 11 // https://github.com/lib/pq/pull/833 12 13 package pgconn 14 15 import ( 16 "bytes" 17 "crypto/hmac" 18 "crypto/rand" 19 "crypto/sha256" 20 "encoding/base64" 21 "errors" 22 "fmt" 23 "strconv" 24 25 "github.com/jackc/pgproto3/v2" 26 "golang.org/x/crypto/pbkdf2" 27 "golang.org/x/text/secure/precis" 28 ) 29 30 const clientNonceLen = 18 31 32 // Perform SCRAM authentication. 33 func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { 34 sc, err := newScramClient(serverAuthMechanisms, c.config.Password) 35 if err != nil { 36 return err 37 } 38 39 // Send client-first-message in a SASLInitialResponse 40 saslInitialResponse := &pgproto3.SASLInitialResponse{ 41 AuthMechanism: "SCRAM-SHA-256", 42 Data: sc.clientFirstMessage(), 43 } 44 _, err = c.conn.Write(saslInitialResponse.Encode(nil)) 45 if err != nil { 46 return err 47 } 48 49 // Receive server-first-message payload in a AuthenticationSASLContinue. 50 saslContinue, err := c.rxSASLContinue() 51 if err != nil { 52 return err 53 } 54 err = sc.recvServerFirstMessage(saslContinue.Data) 55 if err != nil { 56 return err 57 } 58 59 // Send client-final-message in a SASLResponse 60 saslResponse := &pgproto3.SASLResponse{ 61 Data: []byte(sc.clientFinalMessage()), 62 } 63 _, err = c.conn.Write(saslResponse.Encode(nil)) 64 if err != nil { 65 return err 66 } 67 68 // Receive server-final-message payload in a AuthenticationSASLFinal. 69 saslFinal, err := c.rxSASLFinal() 70 if err != nil { 71 return err 72 } 73 return sc.recvServerFinalMessage(saslFinal.Data) 74 } 75 76 func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { 77 msg, err := c.receiveMessage() 78 if err != nil { 79 return nil, err 80 } 81 switch m := msg.(type) { 82 case *pgproto3.AuthenticationSASLContinue: 83 return m, nil 84 case *pgproto3.ErrorResponse: 85 return nil, ErrorResponseToPgError(m) 86 } 87 88 return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg) 89 } 90 91 func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { 92 msg, err := c.receiveMessage() 93 if err != nil { 94 return nil, err 95 } 96 switch m := msg.(type) { 97 case *pgproto3.AuthenticationSASLFinal: 98 return m, nil 99 case *pgproto3.ErrorResponse: 100 return nil, ErrorResponseToPgError(m) 101 } 102 103 return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg) 104 } 105 106 type scramClient struct { 107 serverAuthMechanisms []string 108 password []byte 109 clientNonce []byte 110 111 clientFirstMessageBare []byte 112 113 serverFirstMessage []byte 114 clientAndServerNonce []byte 115 salt []byte 116 iterations int 117 118 saltedPassword []byte 119 authMessage []byte 120 } 121 122 func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { 123 sc := &scramClient{ 124 serverAuthMechanisms: serverAuthMechanisms, 125 } 126 127 // Ensure server supports SCRAM-SHA-256 128 hasScramSHA256 := false 129 for _, mech := range sc.serverAuthMechanisms { 130 if mech == "SCRAM-SHA-256" { 131 hasScramSHA256 = true 132 break 133 } 134 } 135 if !hasScramSHA256 { 136 return nil, errors.New("server does not support SCRAM-SHA-256") 137 } 138 139 // precis.OpaqueString is equivalent to SASLprep for password. 140 var err error 141 sc.password, err = precis.OpaqueString.Bytes([]byte(password)) 142 if err != nil { 143 // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. 144 sc.password = []byte(password) 145 } 146 147 buf := make([]byte, clientNonceLen) 148 _, err = rand.Read(buf) 149 if err != nil { 150 return nil, err 151 } 152 sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) 153 base64.RawStdEncoding.Encode(sc.clientNonce, buf) 154 155 return sc, nil 156 } 157 158 func (sc *scramClient) clientFirstMessage() []byte { 159 sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) 160 return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) 161 } 162 163 func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { 164 sc.serverFirstMessage = serverFirstMessage 165 buf := serverFirstMessage 166 if !bytes.HasPrefix(buf, []byte("r=")) { 167 return errors.New("invalid SCRAM server-first-message received from server: did not include r=") 168 } 169 buf = buf[2:] 170 171 idx := bytes.IndexByte(buf, ',') 172 if idx == -1 { 173 return errors.New("invalid SCRAM server-first-message received from server: did not include s=") 174 } 175 sc.clientAndServerNonce = buf[:idx] 176 buf = buf[idx+1:] 177 178 if !bytes.HasPrefix(buf, []byte("s=")) { 179 return errors.New("invalid SCRAM server-first-message received from server: did not include s=") 180 } 181 buf = buf[2:] 182 183 idx = bytes.IndexByte(buf, ',') 184 if idx == -1 { 185 return errors.New("invalid SCRAM server-first-message received from server: did not include i=") 186 } 187 saltStr := buf[:idx] 188 buf = buf[idx+1:] 189 190 if !bytes.HasPrefix(buf, []byte("i=")) { 191 return errors.New("invalid SCRAM server-first-message received from server: did not include i=") 192 } 193 buf = buf[2:] 194 iterationsStr := buf 195 196 var err error 197 sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) 198 if err != nil { 199 return fmt.Errorf("invalid SCRAM salt received from server: %w", err) 200 } 201 202 sc.iterations, err = strconv.Atoi(string(iterationsStr)) 203 if err != nil || sc.iterations <= 0 { 204 return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) 205 } 206 207 if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { 208 return errors.New("invalid SCRAM nonce: did not start with client nonce") 209 } 210 211 if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { 212 return errors.New("invalid SCRAM nonce: did not include server nonce") 213 } 214 215 return nil 216 } 217 218 func (sc *scramClient) clientFinalMessage() string { 219 clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) 220 221 sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) 222 sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) 223 224 clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) 225 226 return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) 227 } 228 229 func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { 230 if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { 231 return errors.New("invalid SCRAM server-final-message received from server") 232 } 233 234 serverSignature := serverFinalMessage[2:] 235 236 if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { 237 return errors.New("invalid SCRAM ServerSignature received from server") 238 } 239 240 return nil 241 } 242 243 func computeHMAC(key, msg []byte) []byte { 244 mac := hmac.New(sha256.New, key) 245 mac.Write(msg) 246 return mac.Sum(nil) 247 } 248 249 func computeClientProof(saltedPassword, authMessage []byte) []byte { 250 clientKey := computeHMAC(saltedPassword, []byte("Client Key")) 251 storedKey := sha256.Sum256(clientKey) 252 clientSignature := computeHMAC(storedKey[:], authMessage) 253 254 clientProof := make([]byte, len(clientSignature)) 255 for i := 0; i < len(clientSignature); i++ { 256 clientProof[i] = clientKey[i] ^ clientSignature[i] 257 } 258 259 buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) 260 base64.StdEncoding.Encode(buf, clientProof) 261 return buf 262 } 263 264 func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { 265 serverKey := computeHMAC(saltedPassword, []byte("Server Key")) 266 serverSignature := computeHMAC(serverKey, authMessage) 267 buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) 268 base64.StdEncoding.Encode(buf, serverSignature) 269 return buf 270 }