gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

backend.go (6385B)


      1 package pgproto3
      2 
      3 import (
      4 	"bytes"
      5 	"encoding/binary"
      6 	"fmt"
      7 	"io"
      8 )
      9 
// Backend acts as a server for the PostgreSQL wire protocol version 3.
//
// Incoming frontend messages are read through cr. Outgoing backend messages are encoded into wbuf by Send and only
// written to w when Flush is called.
type Backend struct {
	cr *chunkReader
	w  io.Writer

	// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
	// before it is actually transmitted (i.e. before Flush).
	tracer *tracer

	// wbuf holds messages encoded by Send that have not yet been written by Flush.
	wbuf []byte

	// Frontend message flyweights
	//
	// Receive and ReceiveStartupMessage decode into these reused values and return pointers to them, which is why a
	// returned message is only valid until the next receive call. ('p' messages are the exception: Receive allocates
	// a fresh value for them.)
	bind           Bind
	cancelRequest  CancelRequest
	_close         Close
	copyFail       CopyFail
	copyData       CopyData
	copyDone       CopyDone
	describe       Describe
	execute        Execute
	flush          Flush
	functionCall   FunctionCall
	gssEncRequest  GSSEncRequest
	parse          Parse
	query          Query
	sslRequest     SSLRequest
	startupMessage StartupMessage
	sync           Sync
	terminate      Terminate

	// Receive state:
	//   bodyLen    - body length of the message being received (header length field minus 4).
	//   msgType    - type byte of the message being received.
	//   partialMsg - true once a header has been read but its body has not yet been consumed; Receive then skips
	//                reading a new header on its next call.
	//   authType   - value recorded by SetAuthType; Receive uses it to pick the concrete type for 'p' messages.
	bodyLen    int
	msgType    byte
	partialMsg bool
	authType   uint32
}
     45 
     46 const (
     47 	minStartupPacketLen = 4     // minStartupPacketLen is a single 32-bit int version or code.
     48 	maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
     49 )
     50 
     51 // NewBackend creates a new Backend.
     52 func NewBackend(r io.Reader, w io.Writer) *Backend {
     53 	cr := newChunkReader(r, 0)
     54 	return &Backend{cr: cr, w: w}
     55 }
     56 
     57 // Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
     58 // called.
     59 func (b *Backend) Send(msg BackendMessage) {
     60 	prevLen := len(b.wbuf)
     61 	b.wbuf = msg.Encode(b.wbuf)
     62 	if b.tracer != nil {
     63 		b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
     64 	}
     65 }
     66 
     67 // Flush writes any pending messages to the frontend (i.e. the client).
     68 func (b *Backend) Flush() error {
     69 	n, err := b.w.Write(b.wbuf)
     70 
     71 	const maxLen = 1024
     72 	if len(b.wbuf) > maxLen {
     73 		b.wbuf = make([]byte, 0, maxLen)
     74 	} else {
     75 		b.wbuf = b.wbuf[:0]
     76 	}
     77 
     78 	if err != nil {
     79 		return &writeError{err: err, safeToRetry: n == 0}
     80 	}
     81 
     82 	return nil
     83 }
     84 
     85 // Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function
     86 // PQtrace.
     87 func (b *Backend) Trace(w io.Writer, options TracerOptions) {
     88 	b.tracer = &tracer{
     89 		w:             w,
     90 		buf:           &bytes.Buffer{},
     91 		TracerOptions: options,
     92 	}
     93 }
     94 
     95 // Untrace stops tracing.
     96 func (b *Backend) Untrace() {
     97 	b.tracer = nil
     98 }
     99 
    100 // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
    101 // because the initial connection message is "special" and does not include the message type as the first byte. This
    102 // will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
    103 func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
    104 	buf, err := b.cr.Next(4)
    105 	if err != nil {
    106 		return nil, err
    107 	}
    108 	msgSize := int(binary.BigEndian.Uint32(buf) - 4)
    109 
    110 	if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
    111 		return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
    112 	}
    113 
    114 	buf, err = b.cr.Next(msgSize)
    115 	if err != nil {
    116 		return nil, translateEOFtoErrUnexpectedEOF(err)
    117 	}
    118 
    119 	code := binary.BigEndian.Uint32(buf)
    120 
    121 	switch code {
    122 	case ProtocolVersionNumber:
    123 		err = b.startupMessage.Decode(buf)
    124 		if err != nil {
    125 			return nil, err
    126 		}
    127 		return &b.startupMessage, nil
    128 	case sslRequestNumber:
    129 		err = b.sslRequest.Decode(buf)
    130 		if err != nil {
    131 			return nil, err
    132 		}
    133 		return &b.sslRequest, nil
    134 	case cancelRequestCode:
    135 		err = b.cancelRequest.Decode(buf)
    136 		if err != nil {
    137 			return nil, err
    138 		}
    139 		return &b.cancelRequest, nil
    140 	case gssEncReqNumber:
    141 		err = b.gssEncRequest.Decode(buf)
    142 		if err != nil {
    143 			return nil, err
    144 		}
    145 		return &b.gssEncRequest, nil
    146 	default:
    147 		return nil, fmt.Errorf("unknown startup message code: %d", code)
    148 	}
    149 }
    150 
    151 // Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
    152 func (b *Backend) Receive() (FrontendMessage, error) {
    153 	if !b.partialMsg {
    154 		header, err := b.cr.Next(5)
    155 		if err != nil {
    156 			return nil, translateEOFtoErrUnexpectedEOF(err)
    157 		}
    158 
    159 		b.msgType = header[0]
    160 		b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
    161 		b.partialMsg = true
    162 	}
    163 
    164 	var msg FrontendMessage
    165 	switch b.msgType {
    166 	case 'B':
    167 		msg = &b.bind
    168 	case 'C':
    169 		msg = &b._close
    170 	case 'D':
    171 		msg = &b.describe
    172 	case 'E':
    173 		msg = &b.execute
    174 	case 'F':
    175 		msg = &b.functionCall
    176 	case 'f':
    177 		msg = &b.copyFail
    178 	case 'd':
    179 		msg = &b.copyData
    180 	case 'c':
    181 		msg = &b.copyDone
    182 	case 'H':
    183 		msg = &b.flush
    184 	case 'P':
    185 		msg = &b.parse
    186 	case 'p':
    187 		switch b.authType {
    188 		case AuthTypeSASL:
    189 			msg = &SASLInitialResponse{}
    190 		case AuthTypeSASLContinue:
    191 			msg = &SASLResponse{}
    192 		case AuthTypeSASLFinal:
    193 			msg = &SASLResponse{}
    194 		case AuthTypeGSS, AuthTypeGSSCont:
    195 			msg = &GSSResponse{}
    196 		case AuthTypeCleartextPassword, AuthTypeMD5Password:
    197 			fallthrough
    198 		default:
    199 			// to maintain backwards compatibility
    200 			msg = &PasswordMessage{}
    201 		}
    202 	case 'Q':
    203 		msg = &b.query
    204 	case 'S':
    205 		msg = &b.sync
    206 	case 'X':
    207 		msg = &b.terminate
    208 	default:
    209 		return nil, fmt.Errorf("unknown message type: %c", b.msgType)
    210 	}
    211 
    212 	msgBody, err := b.cr.Next(b.bodyLen)
    213 	if err != nil {
    214 		return nil, translateEOFtoErrUnexpectedEOF(err)
    215 	}
    216 
    217 	b.partialMsg = false
    218 
    219 	err = msg.Decode(msgBody)
    220 	if err != nil {
    221 		return nil, err
    222 	}
    223 
    224 	if b.tracer != nil {
    225 		b.tracer.traceMessage('F', int32(5+len(msgBody)), msg)
    226 	}
    227 
    228 	return msg, nil
    229 }
    230 
    231 // SetAuthType sets the authentication type in the backend.
    232 // Since multiple message types can start with 'p', SetAuthType allows
    233 // contextual identification of FrontendMessages. For example, in the
    234 // PG message flow documentation for PasswordMessage:
    235 //
    236 //			Byte1('p')
    237 //
    238 //	     Identifies the message as a password response. Note that this is also used for
    239 //			GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
    240 //			the context.
    241 //
    242 // Since the Frontend does not know about the state of a backend, it is important
    243 // to call SetAuthType() after an authentication request is received by the Frontend.
    244 func (b *Backend) SetAuthType(authType uint32) error {
    245 	switch authType {
    246 	case AuthTypeOk,
    247 		AuthTypeCleartextPassword,
    248 		AuthTypeMD5Password,
    249 		AuthTypeSCMCreds,
    250 		AuthTypeGSS,
    251 		AuthTypeGSSCont,
    252 		AuthTypeSSPI,
    253 		AuthTypeSASL,
    254 		AuthTypeSASLContinue,
    255 		AuthTypeSASLFinal:
    256 		b.authType = authType
    257 	default:
    258 		return fmt.Errorf("authType not recognized: %d", authType)
    259 	}
    260 
    261 	return nil
    262 }