backend.go (5251B)
1 package pgproto3 2 3 import ( 4 "encoding/binary" 5 "errors" 6 "fmt" 7 "io" 8 ) 9 10 // Backend acts as a server for the PostgreSQL wire protocol version 3. 11 type Backend struct { 12 cr ChunkReader 13 w io.Writer 14 15 // Frontend message flyweights 16 bind Bind 17 cancelRequest CancelRequest 18 _close Close 19 copyFail CopyFail 20 copyData CopyData 21 copyDone CopyDone 22 describe Describe 23 execute Execute 24 flush Flush 25 functionCall FunctionCall 26 gssEncRequest GSSEncRequest 27 parse Parse 28 query Query 29 sslRequest SSLRequest 30 startupMessage StartupMessage 31 sync Sync 32 terminate Terminate 33 34 bodyLen int 35 msgType byte 36 partialMsg bool 37 authType uint32 38 } 39 40 const ( 41 minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. 42 maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. 43 ) 44 45 // NewBackend creates a new Backend. 46 func NewBackend(cr ChunkReader, w io.Writer) *Backend { 47 return &Backend{cr: cr, w: w} 48 } 49 50 // Send sends a message to the frontend. 51 func (b *Backend) Send(msg BackendMessage) error { 52 _, err := b.w.Write(msg.Encode(nil)) 53 return err 54 } 55 56 // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method 57 // because the initial connection message is "special" and does not include the message type as the first byte. This 58 // will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. 59 func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { 60 buf, err := b.cr.Next(4) 61 if err != nil { 62 return nil, err 63 } 64 msgSize := int(binary.BigEndian.Uint32(buf) - 4) 65 66 if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { 67 return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) 68 } 69 70 buf, err = b.cr.Next(msgSize) 71 if err != nil { 72 return nil, translateEOFtoErrUnexpectedEOF(err) 73 } 74 75 code := binary.BigEndian.Uint32(buf) 76 77 switch code { 78 case ProtocolVersionNumber: 79 err = b.startupMessage.Decode(buf) 80 if err != nil { 81 return nil, err 82 } 83 return &b.startupMessage, nil 84 case sslRequestNumber: 85 err = b.sslRequest.Decode(buf) 86 if err != nil { 87 return nil, err 88 } 89 return &b.sslRequest, nil 90 case cancelRequestCode: 91 err = b.cancelRequest.Decode(buf) 92 if err != nil { 93 return nil, err 94 } 95 return &b.cancelRequest, nil 96 case gssEncReqNumber: 97 err = b.gssEncRequest.Decode(buf) 98 if err != nil { 99 return nil, err 100 } 101 return &b.gssEncRequest, nil 102 default: 103 return nil, fmt.Errorf("unknown startup message code: %d", code) 104 } 105 } 106 107 // Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. 108 func (b *Backend) Receive() (FrontendMessage, error) { 109 if !b.partialMsg { 110 header, err := b.cr.Next(5) 111 if err != nil { 112 return nil, translateEOFtoErrUnexpectedEOF(err) 113 } 114 115 b.msgType = header[0] 116 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 117 b.partialMsg = true 118 if b.bodyLen < 0 { 119 return nil, errors.New("invalid message with negative body length received") 120 } 121 } 122 123 var msg FrontendMessage 124 switch b.msgType { 125 case 'B': 126 msg = &b.bind 127 case 'C': 128 msg = &b._close 129 case 'D': 130 msg = &b.describe 131 case 'E': 132 msg = &b.execute 133 case 'F': 134 msg = &b.functionCall 135 case 'f': 136 msg = &b.copyFail 137 case 'd': 138 msg = &b.copyData 139 case 'c': 140 msg = &b.copyDone 141 case 'H': 142 msg = &b.flush 143 case 'P': 144 msg = &b.parse 145 case 'p': 146 switch b.authType { 147 case AuthTypeSASL: 148 msg = &SASLInitialResponse{} 149 case AuthTypeSASLContinue: 150 msg = &SASLResponse{} 151 case AuthTypeSASLFinal: 152 msg = &SASLResponse{} 153 case AuthTypeGSS, AuthTypeGSSCont: 154 msg = &GSSResponse{} 155 case AuthTypeCleartextPassword, AuthTypeMD5Password: 156 fallthrough 157 default: 158 // to maintain backwards compatability 159 msg = &PasswordMessage{} 160 } 161 case 'Q': 162 msg = &b.query 163 case 'S': 164 msg = &b.sync 165 case 'X': 166 msg = &b.terminate 167 default: 168 return nil, fmt.Errorf("unknown message type: %c", b.msgType) 169 } 170 171 msgBody, err := b.cr.Next(b.bodyLen) 172 if err != nil { 173 return nil, translateEOFtoErrUnexpectedEOF(err) 174 } 175 176 b.partialMsg = false 177 178 err = msg.Decode(msgBody) 179 return msg, err 180 } 181 182 // SetAuthType sets the authentication type in the backend. 183 // Since multiple message types can start with 'p', SetAuthType allows 184 // contextual identification of FrontendMessages. For example, in the 185 // PG message flow documentation for PasswordMessage: 186 // 187 // Byte1('p') 188 // 189 // Identifies the message as a password response. Note that this is also used for 190 // GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from 191 // the context. 192 // 193 // Since the Frontend does not know about the state of a backend, it is important 194 // to call SetAuthType() after an authentication request is received by the Frontend. 195 func (b *Backend) SetAuthType(authType uint32) error { 196 switch authType { 197 case AuthTypeOk, 198 AuthTypeCleartextPassword, 199 AuthTypeMD5Password, 200 AuthTypeSCMCreds, 201 AuthTypeGSS, 202 AuthTypeGSSCont, 203 AuthTypeSSPI, 204 AuthTypeSASL, 205 AuthTypeSASLContinue, 206 AuthTypeSASLFinal: 207 b.authType = authType 208 default: 209 return fmt.Errorf("authType not recognized: %d", authType) 210 } 211 212 return nil 213 }