gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

bind.go (5621B)


      1 package pgproto3
      2 
      3 import (
      4 	"bytes"
      5 	"encoding/binary"
      6 	"encoding/hex"
      7 	"encoding/json"
      8 	"fmt"
      9 
     10 	"github.com/jackc/pgx/v5/internal/pgio"
     11 )
     12 
     13 type Bind struct {
     14 	DestinationPortal    string
     15 	PreparedStatement    string
     16 	ParameterFormatCodes []int16
     17 	Parameters           [][]byte
     18 	ResultFormatCodes    []int16
     19 }
     20 
     21 // Frontend identifies this message as sendable by a PostgreSQL frontend.
     22 func (*Bind) Frontend() {}
     23 
     24 // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
     25 // type identifier and 4 byte message length.
     26 func (dst *Bind) Decode(src []byte) error {
     27 	*dst = Bind{}
     28 
     29 	idx := bytes.IndexByte(src, 0)
     30 	if idx < 0 {
     31 		return &invalidMessageFormatErr{messageType: "Bind"}
     32 	}
     33 	dst.DestinationPortal = string(src[:idx])
     34 	rp := idx + 1
     35 
     36 	idx = bytes.IndexByte(src[rp:], 0)
     37 	if idx < 0 {
     38 		return &invalidMessageFormatErr{messageType: "Bind"}
     39 	}
     40 	dst.PreparedStatement = string(src[rp : rp+idx])
     41 	rp += idx + 1
     42 
     43 	if len(src[rp:]) < 2 {
     44 		return &invalidMessageFormatErr{messageType: "Bind"}
     45 	}
     46 	parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
     47 	rp += 2
     48 
     49 	if parameterFormatCodeCount > 0 {
     50 		dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
     51 
     52 		if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
     53 			return &invalidMessageFormatErr{messageType: "Bind"}
     54 		}
     55 		for i := 0; i < parameterFormatCodeCount; i++ {
     56 			dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
     57 			rp += 2
     58 		}
     59 	}
     60 
     61 	if len(src[rp:]) < 2 {
     62 		return &invalidMessageFormatErr{messageType: "Bind"}
     63 	}
     64 	parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
     65 	rp += 2
     66 
     67 	if parameterCount > 0 {
     68 		dst.Parameters = make([][]byte, parameterCount)
     69 
     70 		for i := 0; i < parameterCount; i++ {
     71 			if len(src[rp:]) < 4 {
     72 				return &invalidMessageFormatErr{messageType: "Bind"}
     73 			}
     74 
     75 			msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
     76 			rp += 4
     77 
     78 			// null
     79 			if msgSize == -1 {
     80 				continue
     81 			}
     82 
     83 			if len(src[rp:]) < msgSize {
     84 				return &invalidMessageFormatErr{messageType: "Bind"}
     85 			}
     86 
     87 			dst.Parameters[i] = src[rp : rp+msgSize]
     88 			rp += msgSize
     89 		}
     90 	}
     91 
     92 	if len(src[rp:]) < 2 {
     93 		return &invalidMessageFormatErr{messageType: "Bind"}
     94 	}
     95 	resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
     96 	rp += 2
     97 
     98 	dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
     99 	if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
    100 		return &invalidMessageFormatErr{messageType: "Bind"}
    101 	}
    102 	for i := 0; i < resultFormatCodeCount; i++ {
    103 		dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
    104 		rp += 2
    105 	}
    106 
    107 	return nil
    108 }
    109 
    110 // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    111 func (src *Bind) Encode(dst []byte) []byte {
    112 	dst = append(dst, 'B')
    113 	sp := len(dst)
    114 	dst = pgio.AppendInt32(dst, -1)
    115 
    116 	dst = append(dst, src.DestinationPortal...)
    117 	dst = append(dst, 0)
    118 	dst = append(dst, src.PreparedStatement...)
    119 	dst = append(dst, 0)
    120 
    121 	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
    122 	for _, fc := range src.ParameterFormatCodes {
    123 		dst = pgio.AppendInt16(dst, fc)
    124 	}
    125 
    126 	dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
    127 	for _, p := range src.Parameters {
    128 		if p == nil {
    129 			dst = pgio.AppendInt32(dst, -1)
    130 			continue
    131 		}
    132 
    133 		dst = pgio.AppendInt32(dst, int32(len(p)))
    134 		dst = append(dst, p...)
    135 	}
    136 
    137 	dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
    138 	for _, fc := range src.ResultFormatCodes {
    139 		dst = pgio.AppendInt16(dst, fc)
    140 	}
    141 
    142 	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    143 
    144 	return dst
    145 }
    146 
    147 // MarshalJSON implements encoding/json.Marshaler.
    148 func (src Bind) MarshalJSON() ([]byte, error) {
    149 	formattedParameters := make([]map[string]string, len(src.Parameters))
    150 	for i, p := range src.Parameters {
    151 		if p == nil {
    152 			continue
    153 		}
    154 
    155 		textFormat := true
    156 		if len(src.ParameterFormatCodes) == 1 {
    157 			textFormat = src.ParameterFormatCodes[0] == 0
    158 		} else if len(src.ParameterFormatCodes) > 1 {
    159 			textFormat = src.ParameterFormatCodes[i] == 0
    160 		}
    161 
    162 		if textFormat {
    163 			formattedParameters[i] = map[string]string{"text": string(p)}
    164 		} else {
    165 			formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
    166 		}
    167 	}
    168 
    169 	return json.Marshal(struct {
    170 		Type                 string
    171 		DestinationPortal    string
    172 		PreparedStatement    string
    173 		ParameterFormatCodes []int16
    174 		Parameters           []map[string]string
    175 		ResultFormatCodes    []int16
    176 	}{
    177 		Type:                 "Bind",
    178 		DestinationPortal:    src.DestinationPortal,
    179 		PreparedStatement:    src.PreparedStatement,
    180 		ParameterFormatCodes: src.ParameterFormatCodes,
    181 		Parameters:           formattedParameters,
    182 		ResultFormatCodes:    src.ResultFormatCodes,
    183 	})
    184 }
    185 
    186 // UnmarshalJSON implements encoding/json.Unmarshaler.
    187 func (dst *Bind) UnmarshalJSON(data []byte) error {
    188 	// Ignore null, like in the main JSON package.
    189 	if string(data) == "null" {
    190 		return nil
    191 	}
    192 
    193 	var msg struct {
    194 		DestinationPortal    string
    195 		PreparedStatement    string
    196 		ParameterFormatCodes []int16
    197 		Parameters           []map[string]string
    198 		ResultFormatCodes    []int16
    199 	}
    200 	err := json.Unmarshal(data, &msg)
    201 	if err != nil {
    202 		return err
    203 	}
    204 	dst.DestinationPortal = msg.DestinationPortal
    205 	dst.PreparedStatement = msg.PreparedStatement
    206 	dst.ParameterFormatCodes = msg.ParameterFormatCodes
    207 	dst.Parameters = make([][]byte, len(msg.Parameters))
    208 	dst.ResultFormatCodes = msg.ResultFormatCodes
    209 	for n, parameter := range msg.Parameters {
    210 		dst.Parameters[n], err = getValueFromJSON(parameter)
    211 		if err != nil {
    212 			return fmt.Errorf("cannot get param %d: %w", n, err)
    213 		}
    214 	}
    215 	return nil
    216 }