krb5.go (2349B)
1 package pgconn 2 3 import ( 4 "errors" 5 "fmt" 6 7 "github.com/jackc/pgx/v5/pgproto3" 8 ) 9 10 // NewGSSFunc creates a GSS authentication provider, for use with 11 // RegisterGSSProvider. 12 type NewGSSFunc func() (GSS, error) 13 14 var newGSS NewGSSFunc 15 16 // RegisterGSSProvider registers a GSS authentication provider. For example, if 17 // you need to use Kerberos to authenticate with your server, add this to your 18 // main package: 19 // 20 // import "github.com/otan/gopgkrb5" 21 // 22 // func init() { 23 // pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) 24 // } 25 func RegisterGSSProvider(newGSSArg NewGSSFunc) { 26 newGSS = newGSSArg 27 } 28 29 // GSS provides GSSAPI authentication (e.g., Kerberos). 30 type GSS interface { 31 GetInitToken(host string, service string) ([]byte, error) 32 GetInitTokenFromSPN(spn string) ([]byte, error) 33 Continue(inToken []byte) (done bool, outToken []byte, err error) 34 } 35 36 func (c *PgConn) gssAuth() error { 37 if newGSS == nil { 38 return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5") 39 } 40 cli, err := newGSS() 41 if err != nil { 42 return err 43 } 44 45 var nextData []byte 46 if c.config.KerberosSpn != "" { 47 // Use the supplied SPN if provided. 48 nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn) 49 } else { 50 // Allow the kerberos service name to be overridden 51 service := "postgres" 52 if c.config.KerberosSrvName != "" { 53 service = c.config.KerberosSrvName 54 } 55 nextData, err = cli.GetInitToken(c.config.Host, service) 56 } 57 if err != nil { 58 return err 59 } 60 61 for { 62 gssResponse := &pgproto3.GSSResponse{ 63 Data: nextData, 64 } 65 c.frontend.Send(gssResponse) 66 err = c.flushWithPotentialWriteReadDeadlock() 67 if err != nil { 68 return err 69 } 70 resp, err := c.rxGSSContinue() 71 if err != nil { 72 return err 73 } 74 var done bool 75 done, nextData, err = cli.Continue(resp.Data) 76 if err != nil { 77 return err 78 } 79 if done { 80 break 81 } 82 } 83 return nil 84 } 85 86 func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { 87 msg, err := c.receiveMessage() 88 if err != nil { 89 return nil, err 90 } 91 92 switch m := msg.(type) { 93 case *pgproto3.AuthenticationGSSContinue: 94 return m, nil 95 case *pgproto3.ErrorResponse: 96 return nil, ErrorResponseToPgError(m) 97 } 98 99 return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg) 100 }