gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

krb5.go (2315B)


      1 package pgconn
      2 
      3 import (
      4 	"errors"
      5 	"fmt"
      6 
      7 	"github.com/jackc/pgproto3/v2"
      8 )
      9 
     10 // NewGSSFunc creates a GSS authentication provider, for use with
     11 // RegisterGSSProvider.
     12 type NewGSSFunc func() (GSS, error)
     13 
     14 var newGSS NewGSSFunc
     15 
     16 // RegisterGSSProvider registers a GSS authentication provider. For example, if
     17 // you need to use Kerberos to authenticate with your server, add this to your
     18 // main package:
     19 //
     20 //	import "github.com/otan/gopgkrb5"
     21 //
     22 //	func init() {
     23 //		pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
     24 //	}
     25 func RegisterGSSProvider(newGSSArg NewGSSFunc) {
     26 	newGSS = newGSSArg
     27 }
     28 
     29 // GSS provides GSSAPI authentication (e.g., Kerberos).
     30 type GSS interface {
     31 	GetInitToken(host string, service string) ([]byte, error)
     32 	GetInitTokenFromSPN(spn string) ([]byte, error)
     33 	Continue(inToken []byte) (done bool, outToken []byte, err error)
     34 }
     35 
     36 func (c *PgConn) gssAuth() error {
     37 	if newGSS == nil {
     38 		return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
     39 	}
     40 	cli, err := newGSS()
     41 	if err != nil {
     42 		return err
     43 	}
     44 
     45 	var nextData []byte
     46 	if c.config.KerberosSpn != "" {
     47 		// Use the supplied SPN if provided.
     48 		nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
     49 	} else {
     50 		// Allow the kerberos service name to be overridden
     51 		service := "postgres"
     52 		if c.config.KerberosSrvName != "" {
     53 			service = c.config.KerberosSrvName
     54 		}
     55 		nextData, err = cli.GetInitToken(c.config.Host, service)
     56 	}
     57 	if err != nil {
     58 		return err
     59 	}
     60 
     61 	for {
     62 		gssResponse := &pgproto3.GSSResponse{
     63 			Data: nextData,
     64 		}
     65 		_, err = c.conn.Write(gssResponse.Encode(nil))
     66 		if err != nil {
     67 			return err
     68 		}
     69 		resp, err := c.rxGSSContinue()
     70 		if err != nil {
     71 			return err
     72 		}
     73 		var done bool
     74 		done, nextData, err = cli.Continue(resp.Data)
     75 		if err != nil {
     76 			return err
     77 		}
     78 		if done {
     79 			break
     80 		}
     81 	}
     82 	return nil
     83 }
     84 
     85 func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
     86 	msg, err := c.receiveMessage()
     87 	if err != nil {
     88 		return nil, err
     89 	}
     90 
     91 	switch m := msg.(type) {
     92 	case *pgproto3.AuthenticationGSSContinue:
     93 		return m, nil
     94 	case *pgproto3.ErrorResponse:
     95 		return nil, ErrorResponseToPgError(m)
     96 	}
     97 
     98 	return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
     99 }