gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

frontend.go (5561B)


      1 package pgproto3
      2 
      3 import (
      4 	"encoding/binary"
      5 	"errors"
      6 	"fmt"
      7 	"io"
      8 )
      9 
     10 // Frontend acts as a client for the PostgreSQL wire protocol version 3.
     11 type Frontend struct {
     12 	cr ChunkReader
     13 	w  io.Writer
     14 
     15 	// Backend message flyweights
     16 	authenticationOk                AuthenticationOk
     17 	authenticationCleartextPassword AuthenticationCleartextPassword
     18 	authenticationMD5Password       AuthenticationMD5Password
     19 	authenticationGSS               AuthenticationGSS
     20 	authenticationGSSContinue       AuthenticationGSSContinue
     21 	authenticationSASL              AuthenticationSASL
     22 	authenticationSASLContinue      AuthenticationSASLContinue
     23 	authenticationSASLFinal         AuthenticationSASLFinal
     24 	backendKeyData                  BackendKeyData
     25 	bindComplete                    BindComplete
     26 	closeComplete                   CloseComplete
     27 	commandComplete                 CommandComplete
     28 	copyBothResponse                CopyBothResponse
     29 	copyData                        CopyData
     30 	copyInResponse                  CopyInResponse
     31 	copyOutResponse                 CopyOutResponse
     32 	copyDone                        CopyDone
     33 	dataRow                         DataRow
     34 	emptyQueryResponse              EmptyQueryResponse
     35 	errorResponse                   ErrorResponse
     36 	functionCallResponse            FunctionCallResponse
     37 	noData                          NoData
     38 	noticeResponse                  NoticeResponse
     39 	notificationResponse            NotificationResponse
     40 	parameterDescription            ParameterDescription
     41 	parameterStatus                 ParameterStatus
     42 	parseComplete                   ParseComplete
     43 	readyForQuery                   ReadyForQuery
     44 	rowDescription                  RowDescription
     45 	portalSuspended                 PortalSuspended
     46 
     47 	bodyLen    int
     48 	msgType    byte
     49 	partialMsg bool
     50 	authType   uint32
     51 }
     52 
     53 // NewFrontend creates a new Frontend.
     54 func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
     55 	return &Frontend{cr: cr, w: w}
     56 }
     57 
     58 // Send sends a message to the backend.
     59 func (f *Frontend) Send(msg FrontendMessage) error {
     60 	_, err := f.w.Write(msg.Encode(nil))
     61 	return err
     62 }
     63 
     64 func translateEOFtoErrUnexpectedEOF(err error) error {
     65 	if err == io.EOF {
     66 		return io.ErrUnexpectedEOF
     67 	}
     68 	return err
     69 }
     70 
     71 // Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
     72 func (f *Frontend) Receive() (BackendMessage, error) {
     73 	if !f.partialMsg {
     74 		header, err := f.cr.Next(5)
     75 		if err != nil {
     76 			return nil, translateEOFtoErrUnexpectedEOF(err)
     77 		}
     78 
     79 		f.msgType = header[0]
     80 		f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
     81 		f.partialMsg = true
     82 		if f.bodyLen < 0 {
     83 			return nil, errors.New("invalid message with negative body length received")
     84 		}
     85 	}
     86 
     87 	msgBody, err := f.cr.Next(f.bodyLen)
     88 	if err != nil {
     89 		return nil, translateEOFtoErrUnexpectedEOF(err)
     90 	}
     91 
     92 	f.partialMsg = false
     93 
     94 	var msg BackendMessage
     95 	switch f.msgType {
     96 	case '1':
     97 		msg = &f.parseComplete
     98 	case '2':
     99 		msg = &f.bindComplete
    100 	case '3':
    101 		msg = &f.closeComplete
    102 	case 'A':
    103 		msg = &f.notificationResponse
    104 	case 'c':
    105 		msg = &f.copyDone
    106 	case 'C':
    107 		msg = &f.commandComplete
    108 	case 'd':
    109 		msg = &f.copyData
    110 	case 'D':
    111 		msg = &f.dataRow
    112 	case 'E':
    113 		msg = &f.errorResponse
    114 	case 'G':
    115 		msg = &f.copyInResponse
    116 	case 'H':
    117 		msg = &f.copyOutResponse
    118 	case 'I':
    119 		msg = &f.emptyQueryResponse
    120 	case 'K':
    121 		msg = &f.backendKeyData
    122 	case 'n':
    123 		msg = &f.noData
    124 	case 'N':
    125 		msg = &f.noticeResponse
    126 	case 'R':
    127 		var err error
    128 		msg, err = f.findAuthenticationMessageType(msgBody)
    129 		if err != nil {
    130 			return nil, err
    131 		}
    132 	case 's':
    133 		msg = &f.portalSuspended
    134 	case 'S':
    135 		msg = &f.parameterStatus
    136 	case 't':
    137 		msg = &f.parameterDescription
    138 	case 'T':
    139 		msg = &f.rowDescription
    140 	case 'V':
    141 		msg = &f.functionCallResponse
    142 	case 'W':
    143 		msg = &f.copyBothResponse
    144 	case 'Z':
    145 		msg = &f.readyForQuery
    146 	default:
    147 		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
    148 	}
    149 
    150 	err = msg.Decode(msgBody)
    151 	return msg, err
    152 }
    153 
    154 // Authentication message type constants.
    155 // See src/include/libpq/pqcomm.h for all
    156 // constants.
    157 const (
    158 	AuthTypeOk                = 0
    159 	AuthTypeCleartextPassword = 3
    160 	AuthTypeMD5Password       = 5
    161 	AuthTypeSCMCreds          = 6
    162 	AuthTypeGSS               = 7
    163 	AuthTypeGSSCont           = 8
    164 	AuthTypeSSPI              = 9
    165 	AuthTypeSASL              = 10
    166 	AuthTypeSASLContinue      = 11
    167 	AuthTypeSASLFinal         = 12
    168 )
    169 
    170 func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
    171 	if len(src) < 4 {
    172 		return nil, errors.New("authentication message too short")
    173 	}
    174 	f.authType = binary.BigEndian.Uint32(src[:4])
    175 
    176 	switch f.authType {
    177 	case AuthTypeOk:
    178 		return &f.authenticationOk, nil
    179 	case AuthTypeCleartextPassword:
    180 		return &f.authenticationCleartextPassword, nil
    181 	case AuthTypeMD5Password:
    182 		return &f.authenticationMD5Password, nil
    183 	case AuthTypeSCMCreds:
    184 		return nil, errors.New("AuthTypeSCMCreds is unimplemented")
    185 	case AuthTypeGSS:
    186 		return &f.authenticationGSS, nil
    187 	case AuthTypeGSSCont:
    188 		return &f.authenticationGSSContinue, nil
    189 	case AuthTypeSSPI:
    190 		return nil, errors.New("AuthTypeSSPI is unimplemented")
    191 	case AuthTypeSASL:
    192 		return &f.authenticationSASL, nil
    193 	case AuthTypeSASLContinue:
    194 		return &f.authenticationSASLContinue, nil
    195 	case AuthTypeSASLFinal:
    196 		return &f.authenticationSASLFinal, nil
    197 	default:
    198 		return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
    199 	}
    200 }
    201 
    202 // GetAuthType returns the authType used in the current state of the frontend.
    203 // See SetAuthType for more information.
    204 func (f *Frontend) GetAuthType() uint32 {
    205 	return f.authType
    206 }