backend.go (6385B)
package pgproto3

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// Backend acts as a server for the PostgreSQL wire protocol version 3.
//
// Inbound messages are decoded into the flyweight fields below, so a message returned by Receive or
// ReceiveStartupMessage may be overwritten by a later call. Outbound messages are accumulated in wbuf by Send and
// written out by Flush.
type Backend struct {
	// cr is the reader that Receive and ReceiveStartupMessage pull bytes from via Next; w is where Flush writes wbuf.
	cr *chunkReader
	w  io.Writer

	// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
	// before it is actually transmitted (i.e. before Flush).
	tracer *tracer

	// wbuf holds the encoded messages appended by Send that Flush has not yet written.
	wbuf []byte

	// Frontend message flyweights
	bind           Bind
	cancelRequest  CancelRequest
	_close         Close
	copyFail       CopyFail
	copyData       CopyData
	copyDone       CopyDone
	describe       Describe
	execute        Execute
	flush          Flush
	functionCall   FunctionCall
	gssEncRequest  GSSEncRequest
	parse          Parse
	query          Query
	sslRequest     SSLRequest
	startupMessage StartupMessage
	sync           Sync
	terminate      Terminate

	// Header state of the message Receive is currently reading. Receive fills in bodyLen and msgType and sets
	// partialMsg before it reads the body, and clears partialMsg only once the body has been read, so a call that
	// fails while reading the body leaves the header available to the next call.
	bodyLen    int
	msgType    byte
	partialMsg bool

	// authType is assigned by SetAuthType and consulted by Receive to choose the concrete type for 'p' messages.
	authType uint32
}

const (
	minStartupPacketLen = 4     // minStartupPacketLen is a single 32-bit int version or code.
	maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)

// NewBackend creates a new Backend that reads frontend messages from r and writes its own messages to w.
func NewBackend(r io.Reader, w io.Writer) *Backend {
	// NOTE(review): the 0 passed to newChunkReader presumably selects its default buffer size — confirm there.
	cr := newChunkReader(r, 0)
	return &Backend{cr: cr, w: w}
}

// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
// called.
//
// The encoded form of msg is appended to the internal write buffer. When a tracer is installed the message is traced
// immediately, together with the number of bytes its encoding added to the buffer.
func (b *Backend) Send(msg BackendMessage) {
	prevLen := len(b.wbuf)
	b.wbuf = msg.Encode(b.wbuf)
	if b.tracer != nil {
		// The traced size is the growth of wbuf caused by this message only, not the total pending size.
		b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
	}
}

// Flush writes any pending messages to the frontend (i.e. the client).
68 func (b *Backend) Flush() error { 69 n, err := b.w.Write(b.wbuf) 70 71 const maxLen = 1024 72 if len(b.wbuf) > maxLen { 73 b.wbuf = make([]byte, 0, maxLen) 74 } else { 75 b.wbuf = b.wbuf[:0] 76 } 77 78 if err != nil { 79 return &writeError{err: err, safeToRetry: n == 0} 80 } 81 82 return nil 83 } 84 85 // Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function 86 // PQtrace. 87 func (b *Backend) Trace(w io.Writer, options TracerOptions) { 88 b.tracer = &tracer{ 89 w: w, 90 buf: &bytes.Buffer{}, 91 TracerOptions: options, 92 } 93 } 94 95 // Untrace stops tracing. 96 func (b *Backend) Untrace() { 97 b.tracer = nil 98 } 99 100 // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method 101 // because the initial connection message is "special" and does not include the message type as the first byte. This 102 // will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. 
103 func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { 104 buf, err := b.cr.Next(4) 105 if err != nil { 106 return nil, err 107 } 108 msgSize := int(binary.BigEndian.Uint32(buf) - 4) 109 110 if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { 111 return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) 112 } 113 114 buf, err = b.cr.Next(msgSize) 115 if err != nil { 116 return nil, translateEOFtoErrUnexpectedEOF(err) 117 } 118 119 code := binary.BigEndian.Uint32(buf) 120 121 switch code { 122 case ProtocolVersionNumber: 123 err = b.startupMessage.Decode(buf) 124 if err != nil { 125 return nil, err 126 } 127 return &b.startupMessage, nil 128 case sslRequestNumber: 129 err = b.sslRequest.Decode(buf) 130 if err != nil { 131 return nil, err 132 } 133 return &b.sslRequest, nil 134 case cancelRequestCode: 135 err = b.cancelRequest.Decode(buf) 136 if err != nil { 137 return nil, err 138 } 139 return &b.cancelRequest, nil 140 case gssEncReqNumber: 141 err = b.gssEncRequest.Decode(buf) 142 if err != nil { 143 return nil, err 144 } 145 return &b.gssEncRequest, nil 146 default: 147 return nil, fmt.Errorf("unknown startup message code: %d", code) 148 } 149 } 150 151 // Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. 
152 func (b *Backend) Receive() (FrontendMessage, error) { 153 if !b.partialMsg { 154 header, err := b.cr.Next(5) 155 if err != nil { 156 return nil, translateEOFtoErrUnexpectedEOF(err) 157 } 158 159 b.msgType = header[0] 160 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 161 b.partialMsg = true 162 } 163 164 var msg FrontendMessage 165 switch b.msgType { 166 case 'B': 167 msg = &b.bind 168 case 'C': 169 msg = &b._close 170 case 'D': 171 msg = &b.describe 172 case 'E': 173 msg = &b.execute 174 case 'F': 175 msg = &b.functionCall 176 case 'f': 177 msg = &b.copyFail 178 case 'd': 179 msg = &b.copyData 180 case 'c': 181 msg = &b.copyDone 182 case 'H': 183 msg = &b.flush 184 case 'P': 185 msg = &b.parse 186 case 'p': 187 switch b.authType { 188 case AuthTypeSASL: 189 msg = &SASLInitialResponse{} 190 case AuthTypeSASLContinue: 191 msg = &SASLResponse{} 192 case AuthTypeSASLFinal: 193 msg = &SASLResponse{} 194 case AuthTypeGSS, AuthTypeGSSCont: 195 msg = &GSSResponse{} 196 case AuthTypeCleartextPassword, AuthTypeMD5Password: 197 fallthrough 198 default: 199 // to maintain backwards compatibility 200 msg = &PasswordMessage{} 201 } 202 case 'Q': 203 msg = &b.query 204 case 'S': 205 msg = &b.sync 206 case 'X': 207 msg = &b.terminate 208 default: 209 return nil, fmt.Errorf("unknown message type: %c", b.msgType) 210 } 211 212 msgBody, err := b.cr.Next(b.bodyLen) 213 if err != nil { 214 return nil, translateEOFtoErrUnexpectedEOF(err) 215 } 216 217 b.partialMsg = false 218 219 err = msg.Decode(msgBody) 220 if err != nil { 221 return nil, err 222 } 223 224 if b.tracer != nil { 225 b.tracer.traceMessage('F', int32(5+len(msgBody)), msg) 226 } 227 228 return msg, nil 229 } 230 231 // SetAuthType sets the authentication type in the backend. 232 // Since multiple message types can start with 'p', SetAuthType allows 233 // contextual identification of FrontendMessages. 
For example, in the 234 // PG message flow documentation for PasswordMessage: 235 // 236 // Byte1('p') 237 // 238 // Identifies the message as a password response. Note that this is also used for 239 // GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from 240 // the context. 241 // 242 // Since the Frontend does not know about the state of a backend, it is important 243 // to call SetAuthType() after an authentication request is received by the Frontend. 244 func (b *Backend) SetAuthType(authType uint32) error { 245 switch authType { 246 case AuthTypeOk, 247 AuthTypeCleartextPassword, 248 AuthTypeMD5Password, 249 AuthTypeSCMCreds, 250 AuthTypeGSS, 251 AuthTypeGSSCont, 252 AuthTypeSSPI, 253 AuthTypeSASL, 254 AuthTypeSASLContinue, 255 AuthTypeSASLFinal: 256 b.authType = authType 257 default: 258 return fmt.Errorf("authType not recognized: %d", authType) 259 } 260 261 return nil 262 }