krb5.go (2315B)
1 package pgconn 2 3 import ( 4 "errors" 5 "fmt" 6 7 "github.com/jackc/pgproto3/v2" 8 ) 9 10 // NewGSSFunc creates a GSS authentication provider, for use with 11 // RegisterGSSProvider. 12 type NewGSSFunc func() (GSS, error) 13 14 var newGSS NewGSSFunc 15 16 // RegisterGSSProvider registers a GSS authentication provider. For example, if 17 // you need to use Kerberos to authenticate with your server, add this to your 18 // main package: 19 // 20 // import "github.com/otan/gopgkrb5" 21 // 22 // func init() { 23 // pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) 24 // } 25 func RegisterGSSProvider(newGSSArg NewGSSFunc) { 26 newGSS = newGSSArg 27 } 28 29 // GSS provides GSSAPI authentication (e.g., Kerberos). 30 type GSS interface { 31 GetInitToken(host string, service string) ([]byte, error) 32 GetInitTokenFromSPN(spn string) ([]byte, error) 33 Continue(inToken []byte) (done bool, outToken []byte, err error) 34 } 35 36 func (c *PgConn) gssAuth() error { 37 if newGSS == nil { 38 return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5") 39 } 40 cli, err := newGSS() 41 if err != nil { 42 return err 43 } 44 45 var nextData []byte 46 if c.config.KerberosSpn != "" { 47 // Use the supplied SPN if provided. 48 nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn) 49 } else { 50 // Allow the kerberos service name to be overridden 51 service := "postgres" 52 if c.config.KerberosSrvName != "" { 53 service = c.config.KerberosSrvName 54 } 55 nextData, err = cli.GetInitToken(c.config.Host, service) 56 } 57 if err != nil { 58 return err 59 } 60 61 for { 62 gssResponse := &pgproto3.GSSResponse{ 63 Data: nextData, 64 } 65 _, err = c.conn.Write(gssResponse.Encode(nil)) 66 if err != nil { 67 return err 68 } 69 resp, err := c.rxGSSContinue() 70 if err != nil { 71 return err 72 } 73 var done bool 74 done, nextData, err = cli.Continue(resp.Data) 75 if err != nil { 76 return err 77 } 78 if done { 79 break 80 } 81 } 82 return nil 83 } 84 85 func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { 86 msg, err := c.receiveMessage() 87 if err != nil { 88 return nil, err 89 } 90 91 switch m := msg.(type) { 92 case *pgproto3.AuthenticationGSSContinue: 93 return m, nil 94 case *pgproto3.ErrorResponse: 95 return nil, ErrorResponseToPgError(m) 96 } 97 98 return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg) 99 }