gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

krb5.go (2349B)


      1 package pgconn
      2 
      3 import (
      4 	"errors"
      5 	"fmt"
      6 
      7 	"github.com/jackc/pgx/v5/pgproto3"
      8 )
      9 
     10 // NewGSSFunc creates a GSS authentication provider, for use with
     11 // RegisterGSSProvider.
     12 type NewGSSFunc func() (GSS, error)
     13 
     14 var newGSS NewGSSFunc
     15 
     16 // RegisterGSSProvider registers a GSS authentication provider. For example, if
     17 // you need to use Kerberos to authenticate with your server, add this to your
     18 // main package:
     19 //
     20 //	import "github.com/otan/gopgkrb5"
     21 //
     22 //	func init() {
     23 //		pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
     24 //	}
     25 func RegisterGSSProvider(newGSSArg NewGSSFunc) {
     26 	newGSS = newGSSArg
     27 }
     28 
     29 // GSS provides GSSAPI authentication (e.g., Kerberos).
     30 type GSS interface {
     31 	GetInitToken(host string, service string) ([]byte, error)
     32 	GetInitTokenFromSPN(spn string) ([]byte, error)
     33 	Continue(inToken []byte) (done bool, outToken []byte, err error)
     34 }
     35 
     36 func (c *PgConn) gssAuth() error {
     37 	if newGSS == nil {
     38 		return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
     39 	}
     40 	cli, err := newGSS()
     41 	if err != nil {
     42 		return err
     43 	}
     44 
     45 	var nextData []byte
     46 	if c.config.KerberosSpn != "" {
     47 		// Use the supplied SPN if provided.
     48 		nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
     49 	} else {
     50 		// Allow the kerberos service name to be overridden
     51 		service := "postgres"
     52 		if c.config.KerberosSrvName != "" {
     53 			service = c.config.KerberosSrvName
     54 		}
     55 		nextData, err = cli.GetInitToken(c.config.Host, service)
     56 	}
     57 	if err != nil {
     58 		return err
     59 	}
     60 
     61 	for {
     62 		gssResponse := &pgproto3.GSSResponse{
     63 			Data: nextData,
     64 		}
     65 		c.frontend.Send(gssResponse)
     66 		err = c.flushWithPotentialWriteReadDeadlock()
     67 		if err != nil {
     68 			return err
     69 		}
     70 		resp, err := c.rxGSSContinue()
     71 		if err != nil {
     72 			return err
     73 		}
     74 		var done bool
     75 		done, nextData, err = cli.Continue(resp.Data)
     76 		if err != nil {
     77 			return err
     78 		}
     79 		if done {
     80 			break
     81 		}
     82 	}
     83 	return nil
     84 }
     85 
     86 func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
     87 	msg, err := c.receiveMessage()
     88 	if err != nil {
     89 		return nil, err
     90 	}
     91 
     92 	switch m := msg.(type) {
     93 	case *pgproto3.AuthenticationGSSContinue:
     94 		return m, nil
     95 	case *pgproto3.ErrorResponse:
     96 		return nil, ErrorResponseToPgError(m)
     97 	}
     98 
     99 	return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
    100 }