gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

frontend.go (10343B)


      1 package pgproto3
      2 
      3 import (
      4 	"bytes"
      5 	"encoding/binary"
      6 	"errors"
      7 	"fmt"
      8 	"io"
      9 )
     10 
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
//
// Frontend has no internal synchronization: the Send* methods append to a shared write buffer and Receive mutates
// per-connection decoding state and reusable message values.
type Frontend struct {
	// cr is the source of bytes coming from the server (Receive reads headers and bodies from it); w is where
	// outbound bytes are written by Flush and SendUnbufferedEncodedCopyData.
	cr *chunkReader
	w  io.Writer

	// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
	// before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is
	// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
	tracer *tracer

	// wbuf accumulates encoded outbound messages appended by Send and the Send* methods until Flush writes them to w.
	wbuf []byte

	// Backend message flyweights. Receive decodes each incoming message into the matching field below and returns a
	// pointer to that field, which is why a message returned by Receive is only valid until the next call.
	authenticationOk                AuthenticationOk
	authenticationCleartextPassword AuthenticationCleartextPassword
	authenticationMD5Password       AuthenticationMD5Password
	authenticationGSS               AuthenticationGSS
	authenticationGSSContinue       AuthenticationGSSContinue
	authenticationSASL              AuthenticationSASL
	authenticationSASLContinue      AuthenticationSASLContinue
	authenticationSASLFinal         AuthenticationSASLFinal
	backendKeyData                  BackendKeyData
	bindComplete                    BindComplete
	closeComplete                   CloseComplete
	commandComplete                 CommandComplete
	copyBothResponse                CopyBothResponse
	copyData                        CopyData
	copyInResponse                  CopyInResponse
	copyOutResponse                 CopyOutResponse
	copyDone                        CopyDone
	dataRow                         DataRow
	emptyQueryResponse              EmptyQueryResponse
	errorResponse                   ErrorResponse
	functionCallResponse            FunctionCallResponse
	noData                          NoData
	noticeResponse                  NoticeResponse
	notificationResponse            NotificationResponse
	parameterDescription            ParameterDescription
	parameterStatus                 ParameterStatus
	parseComplete                   ParseComplete
	readyForQuery                   ReadyForQuery
	rowDescription                  RowDescription
	portalSuspended                 PortalSuspended

	// Receive state. msgType and bodyLen describe the header of the message currently being read. partialMsg is true
	// once that header has been consumed but the body has not been read successfully yet, so a Receive that failed
	// while reading the body resumes with the body on the next call. authType is the authentication type taken from
	// the most recent Authentication ('R') message and is exposed through GetAuthType.
	bodyLen    int
	msgType    byte
	partialMsg bool
	authType   uint32
}
     60 
     61 // NewFrontend creates a new Frontend.
     62 func NewFrontend(r io.Reader, w io.Writer) *Frontend {
     63 	cr := newChunkReader(r, 0)
     64 	return &Frontend{cr: cr, w: w}
     65 }
     66 
     67 // Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
     68 // called.
     69 //
     70 // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
     71 // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
     72 // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
     73 // behind an interface.
     74 func (f *Frontend) Send(msg FrontendMessage) {
     75 	prevLen := len(f.wbuf)
     76 	f.wbuf = msg.Encode(f.wbuf)
     77 	if f.tracer != nil {
     78 		f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
     79 	}
     80 }
     81 
     82 // Flush writes any pending messages to the backend (i.e. the server).
     83 func (f *Frontend) Flush() error {
     84 	if len(f.wbuf) == 0 {
     85 		return nil
     86 	}
     87 
     88 	n, err := f.w.Write(f.wbuf)
     89 
     90 	const maxLen = 1024
     91 	if len(f.wbuf) > maxLen {
     92 		f.wbuf = make([]byte, 0, maxLen)
     93 	} else {
     94 		f.wbuf = f.wbuf[:0]
     95 	}
     96 
     97 	if err != nil {
     98 		return &writeError{err: err, safeToRetry: n == 0}
     99 	}
    100 
    101 	return nil
    102 }
    103 
    104 // Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function
    105 // PQtrace.
    106 func (f *Frontend) Trace(w io.Writer, options TracerOptions) {
    107 	f.tracer = &tracer{
    108 		w:             w,
    109 		buf:           &bytes.Buffer{},
    110 		TracerOptions: options,
    111 	}
    112 }
    113 
    114 // Untrace stops tracing.
    115 func (f *Frontend) Untrace() {
    116 	f.tracer = nil
    117 }
    118 
    119 // SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until
    120 // Flush is called.
    121 func (f *Frontend) SendBind(msg *Bind) {
    122 	prevLen := len(f.wbuf)
    123 	f.wbuf = msg.Encode(f.wbuf)
    124 	if f.tracer != nil {
    125 		f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
    126 	}
    127 }
    128 
    129 // SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until
    130 // Flush is called.
    131 func (f *Frontend) SendParse(msg *Parse) {
    132 	prevLen := len(f.wbuf)
    133 	f.wbuf = msg.Encode(f.wbuf)
    134 	if f.tracer != nil {
    135 		f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
    136 	}
    137 }
    138 
    139 // SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until
    140 // Flush is called.
    141 func (f *Frontend) SendClose(msg *Close) {
    142 	prevLen := len(f.wbuf)
    143 	f.wbuf = msg.Encode(f.wbuf)
    144 	if f.tracer != nil {
    145 		f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
    146 	}
    147 }
    148 
    149 // SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until
    150 // Flush is called.
    151 func (f *Frontend) SendDescribe(msg *Describe) {
    152 	prevLen := len(f.wbuf)
    153 	f.wbuf = msg.Encode(f.wbuf)
    154 	if f.tracer != nil {
    155 		f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
    156 	}
    157 }
    158 
    159 // SendExecute sends a Execute message to the backend (i.e. the server). The message is not guaranteed to be written until
    160 // Flush is called.
    161 func (f *Frontend) SendExecute(msg *Execute) {
    162 	prevLen := len(f.wbuf)
    163 	f.wbuf = msg.Encode(f.wbuf)
    164 	if f.tracer != nil {
    165 		f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
    166 	}
    167 }
    168 
    169 // SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until
    170 // Flush is called.
    171 func (f *Frontend) SendSync(msg *Sync) {
    172 	prevLen := len(f.wbuf)
    173 	f.wbuf = msg.Encode(f.wbuf)
    174 	if f.tracer != nil {
    175 		f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
    176 	}
    177 }
    178 
    179 // SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until
    180 // Flush is called.
    181 func (f *Frontend) SendQuery(msg *Query) {
    182 	prevLen := len(f.wbuf)
    183 	f.wbuf = msg.Encode(f.wbuf)
    184 	if f.tracer != nil {
    185 		f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
    186 	}
    187 }
    188 
    189 // SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method
    190 // is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer
    191 // before being written out. The internal buffer is flushed before the message is sent.
    192 func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error {
    193 	err := f.Flush()
    194 	if err != nil {
    195 		return err
    196 	}
    197 
    198 	n, err := f.w.Write(msg)
    199 	if err != nil {
    200 		return &writeError{err: err, safeToRetry: n == 0}
    201 	}
    202 
    203 	if f.tracer != nil {
    204 		f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{})
    205 	}
    206 
    207 	return nil
    208 }
    209 
    210 func translateEOFtoErrUnexpectedEOF(err error) error {
    211 	if err == io.EOF {
    212 		return io.ErrUnexpectedEOF
    213 	}
    214 	return err
    215 }
    216 
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
//
// Each message is decoded into a per-type value owned by f that is reused by later calls, which is why the result is
// only valid until the next Receive. If reading the body fails after the 5-byte header has been consumed, the header
// is remembered (partialMsg) and the next call goes straight to reading the body. An io.EOF from the underlying
// reader is reported as io.ErrUnexpectedEOF.
func (f *Frontend) Receive() (BackendMessage, error) {
	if !f.partialMsg {
		// Header: a 1-byte message type, then a big-endian length that counts its own 4 bytes plus the body but not
		// the type byte.
		header, err := f.cr.Next(5)
		if err != nil {
			return nil, translateEOFtoErrUnexpectedEOF(err)
		}

		f.msgType = header[0]

		msgLength := int(binary.BigEndian.Uint32(header[1:]))
		if msgLength < 4 {
			// The length must at least cover the length field itself.
			return nil, fmt.Errorf("invalid message length: %d", msgLength)
		}

		f.bodyLen = msgLength - 4
		// The header is now consumed; remember it so a failed body read can be retried without re-reading it.
		f.partialMsg = true
	}

	msgBody, err := f.cr.Next(f.bodyLen)
	if err != nil {
		// partialMsg stays true, so the next call skips the header and attempts the body read again.
		return nil, translateEOFtoErrUnexpectedEOF(err)
	}

	f.partialMsg = false

	var msg BackendMessage
	switch f.msgType {
	case '1':
		msg = &f.parseComplete
	case '2':
		msg = &f.bindComplete
	case '3':
		msg = &f.closeComplete
	case 'A':
		msg = &f.notificationResponse
	case 'c':
		msg = &f.copyDone
	case 'C':
		msg = &f.commandComplete
	case 'd':
		msg = &f.copyData
	case 'D':
		msg = &f.dataRow
	case 'E':
		msg = &f.errorResponse
	case 'G':
		msg = &f.copyInResponse
	case 'H':
		msg = &f.copyOutResponse
	case 'I':
		msg = &f.emptyQueryResponse
	case 'K':
		msg = &f.backendKeyData
	case 'n':
		msg = &f.noData
	case 'N':
		msg = &f.noticeResponse
	case 'R':
		// Every authentication message uses type 'R'; the concrete message is chosen from the auth type in the body.
		var err error
		msg, err = f.findAuthenticationMessageType(msgBody)
		if err != nil {
			return nil, err
		}
	case 's':
		msg = &f.portalSuspended
	case 'S':
		msg = &f.parameterStatus
	case 't':
		msg = &f.parameterDescription
	case 'T':
		msg = &f.rowDescription
	case 'V':
		msg = &f.functionCallResponse
	case 'W':
		msg = &f.copyBothResponse
	case 'Z':
		msg = &f.readyForQuery
	default:
		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
	}

	err = msg.Decode(msgBody)
	if err != nil {
		return nil, err
	}

	if f.tracer != nil {
		// Trace the full wire size: the 5-byte header plus the body.
		f.tracer.traceMessage('B', int32(5+len(msgBody)), msg)
	}

	return msg, nil
}
    310 
// Authentication message type constants.
// These mirror the constants in PostgreSQL's src/include/libpq/pqcomm.h; see that header for the full list.
//
// Each value is the big-endian uint32 at the start of an Authentication ('R') message body, which Frontend uses to
// choose the Authentication* message to decode. AuthTypeSCMCreds and AuthTypeSSPI are recognized but rejected by
// Frontend as unimplemented.
const (
	AuthTypeOk                = 0
	AuthTypeCleartextPassword = 3
	AuthTypeMD5Password       = 5
	AuthTypeSCMCreds          = 6
	AuthTypeGSS               = 7
	AuthTypeGSSCont           = 8
	AuthTypeSSPI              = 9
	AuthTypeSASL              = 10
	AuthTypeSASLContinue      = 11
	AuthTypeSASLFinal         = 12
)
    326 
    327 func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
    328 	if len(src) < 4 {
    329 		return nil, errors.New("authentication message too short")
    330 	}
    331 	f.authType = binary.BigEndian.Uint32(src[:4])
    332 
    333 	switch f.authType {
    334 	case AuthTypeOk:
    335 		return &f.authenticationOk, nil
    336 	case AuthTypeCleartextPassword:
    337 		return &f.authenticationCleartextPassword, nil
    338 	case AuthTypeMD5Password:
    339 		return &f.authenticationMD5Password, nil
    340 	case AuthTypeSCMCreds:
    341 		return nil, errors.New("AuthTypeSCMCreds is unimplemented")
    342 	case AuthTypeGSS:
    343 		return &f.authenticationGSS, nil
    344 	case AuthTypeGSSCont:
    345 		return &f.authenticationGSSContinue, nil
    346 	case AuthTypeSSPI:
    347 		return nil, errors.New("AuthTypeSSPI is unimplemented")
    348 	case AuthTypeSASL:
    349 		return &f.authenticationSASL, nil
    350 	case AuthTypeSASLContinue:
    351 		return &f.authenticationSASLContinue, nil
    352 	case AuthTypeSASLFinal:
    353 		return &f.authenticationSASLFinal, nil
    354 	default:
    355 		return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
    356 	}
    357 }
    358 
    359 // GetAuthType returns the authType used in the current state of the frontend.
    360 // See SetAuthType for more information.
    361 func (f *Frontend) GetAuthType() uint32 {
    362 	return f.authType
    363 }