bind.go (5605B)
1 package pgproto3 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/hex" 7 "encoding/json" 8 "fmt" 9 10 "github.com/jackc/pgio" 11 ) 12 13 type Bind struct { 14 DestinationPortal string 15 PreparedStatement string 16 ParameterFormatCodes []int16 17 Parameters [][]byte 18 ResultFormatCodes []int16 19 } 20 21 // Frontend identifies this message as sendable by a PostgreSQL frontend. 22 func (*Bind) Frontend() {} 23 24 // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 25 // type identifier and 4 byte message length. 26 func (dst *Bind) Decode(src []byte) error { 27 *dst = Bind{} 28 29 idx := bytes.IndexByte(src, 0) 30 if idx < 0 { 31 return &invalidMessageFormatErr{messageType: "Bind"} 32 } 33 dst.DestinationPortal = string(src[:idx]) 34 rp := idx + 1 35 36 idx = bytes.IndexByte(src[rp:], 0) 37 if idx < 0 { 38 return &invalidMessageFormatErr{messageType: "Bind"} 39 } 40 dst.PreparedStatement = string(src[rp : rp+idx]) 41 rp += idx + 1 42 43 if len(src[rp:]) < 2 { 44 return &invalidMessageFormatErr{messageType: "Bind"} 45 } 46 parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) 47 rp += 2 48 49 if parameterFormatCodeCount > 0 { 50 dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) 51 52 if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { 53 return &invalidMessageFormatErr{messageType: "Bind"} 54 } 55 for i := 0; i < parameterFormatCodeCount; i++ { 56 dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) 57 rp += 2 58 } 59 } 60 61 if len(src[rp:]) < 2 { 62 return &invalidMessageFormatErr{messageType: "Bind"} 63 } 64 parameterCount := int(binary.BigEndian.Uint16(src[rp:])) 65 rp += 2 66 67 if parameterCount > 0 { 68 dst.Parameters = make([][]byte, parameterCount) 69 70 for i := 0; i < parameterCount; i++ { 71 if len(src[rp:]) < 4 { 72 return &invalidMessageFormatErr{messageType: "Bind"} 73 } 74 75 msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) 76 rp += 4 77 78 // null 79 if msgSize == -1 { 80 continue 81 } 82 83 if len(src[rp:]) < msgSize { 84 return &invalidMessageFormatErr{messageType: "Bind"} 85 } 86 87 dst.Parameters[i] = src[rp : rp+msgSize] 88 rp += msgSize 89 } 90 } 91 92 if len(src[rp:]) < 2 { 93 return &invalidMessageFormatErr{messageType: "Bind"} 94 } 95 resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) 96 rp += 2 97 98 dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) 99 if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { 100 return &invalidMessageFormatErr{messageType: "Bind"} 101 } 102 for i := 0; i < resultFormatCodeCount; i++ { 103 dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) 104 rp += 2 105 } 106 107 return nil 108 } 109 110 // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 111 func (src *Bind) Encode(dst []byte) []byte { 112 dst = append(dst, 'B') 113 sp := len(dst) 114 dst = pgio.AppendInt32(dst, -1) 115 116 dst = append(dst, src.DestinationPortal...) 117 dst = append(dst, 0) 118 dst = append(dst, src.PreparedStatement...) 119 dst = append(dst, 0) 120 121 dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) 122 for _, fc := range src.ParameterFormatCodes { 123 dst = pgio.AppendInt16(dst, fc) 124 } 125 126 dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) 127 for _, p := range src.Parameters { 128 if p == nil { 129 dst = pgio.AppendInt32(dst, -1) 130 continue 131 } 132 133 dst = pgio.AppendInt32(dst, int32(len(p))) 134 dst = append(dst, p...) 135 } 136 137 dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) 138 for _, fc := range src.ResultFormatCodes { 139 dst = pgio.AppendInt16(dst, fc) 140 } 141 142 pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 143 144 return dst 145 } 146 147 // MarshalJSON implements encoding/json.Marshaler. 148 func (src Bind) MarshalJSON() ([]byte, error) { 149 formattedParameters := make([]map[string]string, len(src.Parameters)) 150 for i, p := range src.Parameters { 151 if p == nil { 152 continue 153 } 154 155 textFormat := true 156 if len(src.ParameterFormatCodes) == 1 { 157 textFormat = src.ParameterFormatCodes[0] == 0 158 } else if len(src.ParameterFormatCodes) > 1 { 159 textFormat = src.ParameterFormatCodes[i] == 0 160 } 161 162 if textFormat { 163 formattedParameters[i] = map[string]string{"text": string(p)} 164 } else { 165 formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} 166 } 167 } 168 169 return json.Marshal(struct { 170 Type string 171 DestinationPortal string 172 PreparedStatement string 173 ParameterFormatCodes []int16 174 Parameters []map[string]string 175 ResultFormatCodes []int16 176 }{ 177 Type: "Bind", 178 DestinationPortal: src.DestinationPortal, 179 PreparedStatement: src.PreparedStatement, 180 ParameterFormatCodes: src.ParameterFormatCodes, 181 Parameters: formattedParameters, 182 ResultFormatCodes: src.ResultFormatCodes, 183 }) 184 } 185 186 // UnmarshalJSON implements encoding/json.Unmarshaler. 187 func (dst *Bind) UnmarshalJSON(data []byte) error { 188 // Ignore null, like in the main JSON package. 189 if string(data) == "null" { 190 return nil 191 } 192 193 var msg struct { 194 DestinationPortal string 195 PreparedStatement string 196 ParameterFormatCodes []int16 197 Parameters []map[string]string 198 ResultFormatCodes []int16 199 } 200 err := json.Unmarshal(data, &msg) 201 if err != nil { 202 return err 203 } 204 dst.DestinationPortal = msg.DestinationPortal 205 dst.PreparedStatement = msg.PreparedStatement 206 dst.ParameterFormatCodes = msg.ParameterFormatCodes 207 dst.Parameters = make([][]byte, len(msg.Parameters)) 208 dst.ResultFormatCodes = msg.ResultFormatCodes 209 for n, parameter := range msg.Parameters { 210 dst.Parameters[n], err = getValueFromJSON(parameter) 211 if err != nil { 212 return fmt.Errorf("cannot get param %d: %w", n, err) 213 } 214 } 215 return nil 216 }