frontend.go (5561B)
1 package pgproto3 2 3 import ( 4 "encoding/binary" 5 "errors" 6 "fmt" 7 "io" 8 ) 9 10 // Frontend acts as a client for the PostgreSQL wire protocol version 3. 11 type Frontend struct { 12 cr ChunkReader 13 w io.Writer 14 15 // Backend message flyweights 16 authenticationOk AuthenticationOk 17 authenticationCleartextPassword AuthenticationCleartextPassword 18 authenticationMD5Password AuthenticationMD5Password 19 authenticationGSS AuthenticationGSS 20 authenticationGSSContinue AuthenticationGSSContinue 21 authenticationSASL AuthenticationSASL 22 authenticationSASLContinue AuthenticationSASLContinue 23 authenticationSASLFinal AuthenticationSASLFinal 24 backendKeyData BackendKeyData 25 bindComplete BindComplete 26 closeComplete CloseComplete 27 commandComplete CommandComplete 28 copyBothResponse CopyBothResponse 29 copyData CopyData 30 copyInResponse CopyInResponse 31 copyOutResponse CopyOutResponse 32 copyDone CopyDone 33 dataRow DataRow 34 emptyQueryResponse EmptyQueryResponse 35 errorResponse ErrorResponse 36 functionCallResponse FunctionCallResponse 37 noData NoData 38 noticeResponse NoticeResponse 39 notificationResponse NotificationResponse 40 parameterDescription ParameterDescription 41 parameterStatus ParameterStatus 42 parseComplete ParseComplete 43 readyForQuery ReadyForQuery 44 rowDescription RowDescription 45 portalSuspended PortalSuspended 46 47 bodyLen int 48 msgType byte 49 partialMsg bool 50 authType uint32 51 } 52 53 // NewFrontend creates a new Frontend. 54 func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { 55 return &Frontend{cr: cr, w: w} 56 } 57 58 // Send sends a message to the backend. 59 func (f *Frontend) Send(msg FrontendMessage) error { 60 _, err := f.w.Write(msg.Encode(nil)) 61 return err 62 } 63 64 func translateEOFtoErrUnexpectedEOF(err error) error { 65 if err == io.EOF { 66 return io.ErrUnexpectedEOF 67 } 68 return err 69 } 70 71 // Receive receives a message from the backend. The returned message is only valid until the next call to Receive. 72 func (f *Frontend) Receive() (BackendMessage, error) { 73 if !f.partialMsg { 74 header, err := f.cr.Next(5) 75 if err != nil { 76 return nil, translateEOFtoErrUnexpectedEOF(err) 77 } 78 79 f.msgType = header[0] 80 f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 81 f.partialMsg = true 82 if f.bodyLen < 0 { 83 return nil, errors.New("invalid message with negative body length received") 84 } 85 } 86 87 msgBody, err := f.cr.Next(f.bodyLen) 88 if err != nil { 89 return nil, translateEOFtoErrUnexpectedEOF(err) 90 } 91 92 f.partialMsg = false 93 94 var msg BackendMessage 95 switch f.msgType { 96 case '1': 97 msg = &f.parseComplete 98 case '2': 99 msg = &f.bindComplete 100 case '3': 101 msg = &f.closeComplete 102 case 'A': 103 msg = &f.notificationResponse 104 case 'c': 105 msg = &f.copyDone 106 case 'C': 107 msg = &f.commandComplete 108 case 'd': 109 msg = &f.copyData 110 case 'D': 111 msg = &f.dataRow 112 case 'E': 113 msg = &f.errorResponse 114 case 'G': 115 msg = &f.copyInResponse 116 case 'H': 117 msg = &f.copyOutResponse 118 case 'I': 119 msg = &f.emptyQueryResponse 120 case 'K': 121 msg = &f.backendKeyData 122 case 'n': 123 msg = &f.noData 124 case 'N': 125 msg = &f.noticeResponse 126 case 'R': 127 var err error 128 msg, err = f.findAuthenticationMessageType(msgBody) 129 if err != nil { 130 return nil, err 131 } 132 case 's': 133 msg = &f.portalSuspended 134 case 'S': 135 msg = &f.parameterStatus 136 case 't': 137 msg = &f.parameterDescription 138 case 'T': 139 msg = &f.rowDescription 140 case 'V': 141 msg = &f.functionCallResponse 142 case 'W': 143 msg = &f.copyBothResponse 144 case 'Z': 145 msg = &f.readyForQuery 146 default: 147 return nil, fmt.Errorf("unknown message type: %c", f.msgType) 148 } 149 150 err = msg.Decode(msgBody) 151 return msg, err 152 } 153 154 // Authentication message type constants. 155 // See src/include/libpq/pqcomm.h for all 156 // constants. 157 const ( 158 AuthTypeOk = 0 159 AuthTypeCleartextPassword = 3 160 AuthTypeMD5Password = 5 161 AuthTypeSCMCreds = 6 162 AuthTypeGSS = 7 163 AuthTypeGSSCont = 8 164 AuthTypeSSPI = 9 165 AuthTypeSASL = 10 166 AuthTypeSASLContinue = 11 167 AuthTypeSASLFinal = 12 168 ) 169 170 func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { 171 if len(src) < 4 { 172 return nil, errors.New("authentication message too short") 173 } 174 f.authType = binary.BigEndian.Uint32(src[:4]) 175 176 switch f.authType { 177 case AuthTypeOk: 178 return &f.authenticationOk, nil 179 case AuthTypeCleartextPassword: 180 return &f.authenticationCleartextPassword, nil 181 case AuthTypeMD5Password: 182 return &f.authenticationMD5Password, nil 183 case AuthTypeSCMCreds: 184 return nil, errors.New("AuthTypeSCMCreds is unimplemented") 185 case AuthTypeGSS: 186 return &f.authenticationGSS, nil 187 case AuthTypeGSSCont: 188 return &f.authenticationGSSContinue, nil 189 case AuthTypeSSPI: 190 return nil, errors.New("AuthTypeSSPI is unimplemented") 191 case AuthTypeSASL: 192 return &f.authenticationSASL, nil 193 case AuthTypeSASLContinue: 194 return &f.authenticationSASLContinue, nil 195 case AuthTypeSASLFinal: 196 return &f.authenticationSASLFinal, nil 197 default: 198 return nil, fmt.Errorf("unknown authentication type: %d", f.authType) 199 } 200 } 201 202 // GetAuthType returns the authType used in the current state of the frontend. 203 // See SetAuthType for more information. 204 func (f *Frontend) GetAuthType() uint32 { 205 return f.authType 206 }