frontend.go (10343B)
package pgproto3

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
)

// Frontend acts as a client for the PostgreSQL wire protocol version 3.
//
// Outbound messages are encoded into an internal write buffer by Send and the Send* helpers and are only handed to
// the underlying io.Writer by Flush. Inbound messages are read by Receive and decoded into one of the per-type
// flyweight fields below, so a message returned by Receive is only valid until the next call to Receive.
type Frontend struct {
	// cr buffers reads from the io.Reader given to NewFrontend; Receive pulls message headers and bodies from it.
	cr *chunkReader
	// w receives the encoded outbound messages when Flush is called.
	w io.Writer

	// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
	// before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is
	// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
	tracer *tracer

	// wbuf accumulates encoded outbound messages until Flush writes them to w.
	wbuf []byte

	// Backend message flyweights. Receive decodes each incoming message into the matching field and returns a
	// pointer to it instead of allocating a new message.
	authenticationOk                AuthenticationOk
	authenticationCleartextPassword AuthenticationCleartextPassword
	authenticationMD5Password       AuthenticationMD5Password
	authenticationGSS               AuthenticationGSS
	authenticationGSSContinue       AuthenticationGSSContinue
	authenticationSASL              AuthenticationSASL
	authenticationSASLContinue      AuthenticationSASLContinue
	authenticationSASLFinal         AuthenticationSASLFinal
	backendKeyData                  BackendKeyData
	bindComplete                    BindComplete
	closeComplete                   CloseComplete
	commandComplete                 CommandComplete
	copyBothResponse                CopyBothResponse
	copyData                        CopyData
	copyInResponse                  CopyInResponse
	copyOutResponse                 CopyOutResponse
	copyDone                        CopyDone
	dataRow                         DataRow
	emptyQueryResponse              EmptyQueryResponse
	errorResponse                   ErrorResponse
	functionCallResponse            FunctionCallResponse
	noData                          NoData
	noticeResponse                  NoticeResponse
	notificationResponse            NotificationResponse
	parameterDescription            ParameterDescription
	parameterStatus                 ParameterStatus
	parseComplete                   ParseComplete
	readyForQuery                   ReadyForQuery
	rowDescription                  RowDescription
	portalSuspended                 PortalSuspended

	// bodyLen is the body length of the message currently being read: the header's length field minus the 4 bytes
	// of the length field itself.
	bodyLen int
	// msgType is the type byte from the header of the message currently being read.
	msgType byte
	// partialMsg is true after Receive has read a header but before it has successfully read the matching body, so
	// the next Receive resumes with the body instead of reading a new header.
	partialMsg bool
	// authType is the authentication request code read from the most recent Authentication message; see
	// GetAuthType.
	authType uint32
}

// NewFrontend creates a new Frontend.
62 func NewFrontend(r io.Reader, w io.Writer) *Frontend { 63 cr := newChunkReader(r, 0) 64 return &Frontend{cr: cr, w: w} 65 } 66 67 // Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is 68 // called. 69 // 70 // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods 71 // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an 72 // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden 73 // behind an interface. 74 func (f *Frontend) Send(msg FrontendMessage) { 75 prevLen := len(f.wbuf) 76 f.wbuf = msg.Encode(f.wbuf) 77 if f.tracer != nil { 78 f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) 79 } 80 } 81 82 // Flush writes any pending messages to the backend (i.e. the server). 83 func (f *Frontend) Flush() error { 84 if len(f.wbuf) == 0 { 85 return nil 86 } 87 88 n, err := f.w.Write(f.wbuf) 89 90 const maxLen = 1024 91 if len(f.wbuf) > maxLen { 92 f.wbuf = make([]byte, 0, maxLen) 93 } else { 94 f.wbuf = f.wbuf[:0] 95 } 96 97 if err != nil { 98 return &writeError{err: err, safeToRetry: n == 0} 99 } 100 101 return nil 102 } 103 104 // Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function 105 // PQtrace. 106 func (f *Frontend) Trace(w io.Writer, options TracerOptions) { 107 f.tracer = &tracer{ 108 w: w, 109 buf: &bytes.Buffer{}, 110 TracerOptions: options, 111 } 112 } 113 114 // Untrace stops tracing. 115 func (f *Frontend) Untrace() { 116 f.tracer = nil 117 } 118 119 // SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until 120 // Flush is called. 
121 func (f *Frontend) SendBind(msg *Bind) { 122 prevLen := len(f.wbuf) 123 f.wbuf = msg.Encode(f.wbuf) 124 if f.tracer != nil { 125 f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) 126 } 127 } 128 129 // SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until 130 // Flush is called. 131 func (f *Frontend) SendParse(msg *Parse) { 132 prevLen := len(f.wbuf) 133 f.wbuf = msg.Encode(f.wbuf) 134 if f.tracer != nil { 135 f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) 136 } 137 } 138 139 // SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until 140 // Flush is called. 141 func (f *Frontend) SendClose(msg *Close) { 142 prevLen := len(f.wbuf) 143 f.wbuf = msg.Encode(f.wbuf) 144 if f.tracer != nil { 145 f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) 146 } 147 } 148 149 // SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until 150 // Flush is called. 151 func (f *Frontend) SendDescribe(msg *Describe) { 152 prevLen := len(f.wbuf) 153 f.wbuf = msg.Encode(f.wbuf) 154 if f.tracer != nil { 155 f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) 156 } 157 } 158 159 // SendExecute sends a Execute message to the backend (i.e. the server). The message is not guaranteed to be written until 160 // Flush is called. 161 func (f *Frontend) SendExecute(msg *Execute) { 162 prevLen := len(f.wbuf) 163 f.wbuf = msg.Encode(f.wbuf) 164 if f.tracer != nil { 165 f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) 166 } 167 } 168 169 // SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until 170 // Flush is called. 
171 func (f *Frontend) SendSync(msg *Sync) { 172 prevLen := len(f.wbuf) 173 f.wbuf = msg.Encode(f.wbuf) 174 if f.tracer != nil { 175 f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) 176 } 177 } 178 179 // SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until 180 // Flush is called. 181 func (f *Frontend) SendQuery(msg *Query) { 182 prevLen := len(f.wbuf) 183 f.wbuf = msg.Encode(f.wbuf) 184 if f.tracer != nil { 185 f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) 186 } 187 } 188 189 // SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method 190 // is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer 191 // before being written out. The internal buffer is flushed before the message is sent. 192 func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error { 193 err := f.Flush() 194 if err != nil { 195 return err 196 } 197 198 n, err := f.w.Write(msg) 199 if err != nil { 200 return &writeError{err: err, safeToRetry: n == 0} 201 } 202 203 if f.tracer != nil { 204 f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{}) 205 } 206 207 return nil 208 } 209 210 func translateEOFtoErrUnexpectedEOF(err error) error { 211 if err == io.EOF { 212 return io.ErrUnexpectedEOF 213 } 214 return err 215 } 216 217 // Receive receives a message from the backend. The returned message is only valid until the next call to Receive. 
218 func (f *Frontend) Receive() (BackendMessage, error) { 219 if !f.partialMsg { 220 header, err := f.cr.Next(5) 221 if err != nil { 222 return nil, translateEOFtoErrUnexpectedEOF(err) 223 } 224 225 f.msgType = header[0] 226 227 msgLength := int(binary.BigEndian.Uint32(header[1:])) 228 if msgLength < 4 { 229 return nil, fmt.Errorf("invalid message length: %d", msgLength) 230 } 231 232 f.bodyLen = msgLength - 4 233 f.partialMsg = true 234 } 235 236 msgBody, err := f.cr.Next(f.bodyLen) 237 if err != nil { 238 return nil, translateEOFtoErrUnexpectedEOF(err) 239 } 240 241 f.partialMsg = false 242 243 var msg BackendMessage 244 switch f.msgType { 245 case '1': 246 msg = &f.parseComplete 247 case '2': 248 msg = &f.bindComplete 249 case '3': 250 msg = &f.closeComplete 251 case 'A': 252 msg = &f.notificationResponse 253 case 'c': 254 msg = &f.copyDone 255 case 'C': 256 msg = &f.commandComplete 257 case 'd': 258 msg = &f.copyData 259 case 'D': 260 msg = &f.dataRow 261 case 'E': 262 msg = &f.errorResponse 263 case 'G': 264 msg = &f.copyInResponse 265 case 'H': 266 msg = &f.copyOutResponse 267 case 'I': 268 msg = &f.emptyQueryResponse 269 case 'K': 270 msg = &f.backendKeyData 271 case 'n': 272 msg = &f.noData 273 case 'N': 274 msg = &f.noticeResponse 275 case 'R': 276 var err error 277 msg, err = f.findAuthenticationMessageType(msgBody) 278 if err != nil { 279 return nil, err 280 } 281 case 's': 282 msg = &f.portalSuspended 283 case 'S': 284 msg = &f.parameterStatus 285 case 't': 286 msg = &f.parameterDescription 287 case 'T': 288 msg = &f.rowDescription 289 case 'V': 290 msg = &f.functionCallResponse 291 case 'W': 292 msg = &f.copyBothResponse 293 case 'Z': 294 msg = &f.readyForQuery 295 default: 296 return nil, fmt.Errorf("unknown message type: %c", f.msgType) 297 } 298 299 err = msg.Decode(msgBody) 300 if err != nil { 301 return nil, err 302 } 303 304 if f.tracer != nil { 305 f.tracer.traceMessage('B', int32(5+len(msgBody)), msg) 306 } 307 308 return msg, nil 309 } 310 
311 // Authentication message type constants. 312 // See src/include/libpq/pqcomm.h for all 313 // constants. 314 const ( 315 AuthTypeOk = 0 316 AuthTypeCleartextPassword = 3 317 AuthTypeMD5Password = 5 318 AuthTypeSCMCreds = 6 319 AuthTypeGSS = 7 320 AuthTypeGSSCont = 8 321 AuthTypeSSPI = 9 322 AuthTypeSASL = 10 323 AuthTypeSASLContinue = 11 324 AuthTypeSASLFinal = 12 325 ) 326 327 func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { 328 if len(src) < 4 { 329 return nil, errors.New("authentication message too short") 330 } 331 f.authType = binary.BigEndian.Uint32(src[:4]) 332 333 switch f.authType { 334 case AuthTypeOk: 335 return &f.authenticationOk, nil 336 case AuthTypeCleartextPassword: 337 return &f.authenticationCleartextPassword, nil 338 case AuthTypeMD5Password: 339 return &f.authenticationMD5Password, nil 340 case AuthTypeSCMCreds: 341 return nil, errors.New("AuthTypeSCMCreds is unimplemented") 342 case AuthTypeGSS: 343 return &f.authenticationGSS, nil 344 case AuthTypeGSSCont: 345 return &f.authenticationGSSContinue, nil 346 case AuthTypeSSPI: 347 return nil, errors.New("AuthTypeSSPI is unimplemented") 348 case AuthTypeSASL: 349 return &f.authenticationSASL, nil 350 case AuthTypeSASLContinue: 351 return &f.authenticationSASLContinue, nil 352 case AuthTypeSASLFinal: 353 return &f.authenticationSASLFinal, nil 354 default: 355 return nil, fmt.Errorf("unknown authentication type: %d", f.authType) 356 } 357 } 358 359 // GetAuthType returns the authType used in the current state of the frontend. 360 // See SetAuthType for more information. 361 func (f *Frontend) GetAuthType() uint32 { 362 return f.authType 363 }