gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

extended_query_builder.go (6166B)


      1 package pgx
      2 
      3 import (
      4 	"database/sql/driver"
      5 	"fmt"
      6 
      7 	"github.com/jackc/pgx/v5/internal/anynil"
      8 	"github.com/jackc/pgx/v5/pgconn"
      9 	"github.com/jackc/pgx/v5/pgtype"
     10 )
     11 
     12 // ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
     13 // formats for an extended query.
     14 type ExtendedQueryBuilder struct {
     15 	ParamValues     [][]byte
     16 	paramValueBytes []byte
     17 	ParamFormats    []int16
     18 	ResultFormats   []int16
     19 }
     20 
     21 // Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
     22 // sd is nil then QueryExecModeExec behavior will be used.
     23 func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
     24 	eqb.reset()
     25 
     26 	anynil.NormalizeSlice(args)
     27 
     28 	if sd == nil {
     29 		return eqb.appendParamsForQueryExecModeExec(m, args)
     30 	}
     31 
     32 	if len(sd.ParamOIDs) != len(args) {
     33 		return fmt.Errorf("mismatched param and argument count")
     34 	}
     35 
     36 	for i := range args {
     37 		err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
     38 		if err != nil {
     39 			err = fmt.Errorf("failed to encode args[%d]: %v", i, err)
     40 			return err
     41 		}
     42 	}
     43 
     44 	for i := range sd.Fields {
     45 		eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
     46 	}
     47 
     48 	return nil
     49 }
     50 
     51 // appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
     52 // must be an untyped nil.
     53 func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
     54 	if format == -1 {
     55 		preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
     56 		preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
     57 		if preferredErr == nil {
     58 			return nil
     59 		}
     60 
     61 		var otherFormat int16
     62 		if preferredFormat == TextFormatCode {
     63 			otherFormat = BinaryFormatCode
     64 		} else {
     65 			otherFormat = TextFormatCode
     66 		}
     67 
     68 		otherErr := eqb.appendParam(m, oid, otherFormat, arg)
     69 		if otherErr == nil {
     70 			return nil
     71 		}
     72 
     73 		return preferredErr // return the error from the preferred format
     74 	}
     75 
     76 	v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
     77 	if err != nil {
     78 		return err
     79 	}
     80 
     81 	eqb.ParamFormats = append(eqb.ParamFormats, format)
     82 	eqb.ParamValues = append(eqb.ParamValues, v)
     83 
     84 	return nil
     85 }
     86 
     87 // appendResultFormat appends a result format to the query.
     88 func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
     89 	eqb.ResultFormats = append(eqb.ResultFormats, format)
     90 }
     91 
     92 // reset readies eqb to build another query.
     93 func (eqb *ExtendedQueryBuilder) reset() {
     94 	eqb.ParamValues = eqb.ParamValues[0:0]
     95 	eqb.paramValueBytes = eqb.paramValueBytes[0:0]
     96 	eqb.ParamFormats = eqb.ParamFormats[0:0]
     97 	eqb.ResultFormats = eqb.ResultFormats[0:0]
     98 
     99 	if cap(eqb.ParamValues) > 64 {
    100 		eqb.ParamValues = make([][]byte, 0, 64)
    101 	}
    102 
    103 	if cap(eqb.paramValueBytes) > 256 {
    104 		eqb.paramValueBytes = make([]byte, 0, 256)
    105 	}
    106 
    107 	if cap(eqb.ParamFormats) > 64 {
    108 		eqb.ParamFormats = make([]int16, 0, 64)
    109 	}
    110 	if cap(eqb.ResultFormats) > 64 {
    111 		eqb.ResultFormats = make([]int16, 0, 64)
    112 	}
    113 }
    114 
    115 func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
    116 	if anynil.Is(arg) {
    117 		return nil, nil
    118 	}
    119 
    120 	if eqb.paramValueBytes == nil {
    121 		eqb.paramValueBytes = make([]byte, 0, 128)
    122 	}
    123 
    124 	pos := len(eqb.paramValueBytes)
    125 
    126 	buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
    127 	if err != nil {
    128 		return nil, err
    129 	}
    130 	if buf == nil {
    131 		return nil, nil
    132 	}
    133 	eqb.paramValueBytes = buf
    134 	return eqb.paramValueBytes[pos:], nil
    135 }
    136 
    137 // chooseParameterFormatCode determines the correct format code for an
    138 // argument to a prepared statement. It defaults to TextFormatCode if no
    139 // determination can be made.
    140 func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
    141 	switch arg.(type) {
    142 	case string, *string:
    143 		return TextFormatCode
    144 	}
    145 
    146 	return m.FormatCodeForOID(oid)
    147 }
    148 
    149 // appendParamsForQueryExecModeExec appends the args to eqb.
    150 //
    151 // Parameters must be encoded in the text format because of differences in type conversion between timestamps and
    152 // dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
    153 // Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
    154 // PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
    155 // type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
    156 // This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
    157 // before converting it to date. This means that dates can be shifted by one day. In text format without that double
    158 // type conversion it takes the date directly and ignores time zone (i.e. it works).
    159 //
    160 // Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
    161 // no way to safely use binary or to specify the parameter OIDs.
    162 func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
    163 	for _, arg := range args {
    164 		if arg == nil {
    165 			err := eqb.appendParam(m, 0, TextFormatCode, arg)
    166 			if err != nil {
    167 				return err
    168 			}
    169 		} else {
    170 			dt, ok := m.TypeForValue(arg)
    171 			if !ok {
    172 				var tv pgtype.TextValuer
    173 				if tv, ok = arg.(pgtype.TextValuer); ok {
    174 					t, err := tv.TextValue()
    175 					if err != nil {
    176 						return err
    177 					}
    178 
    179 					dt, ok = m.TypeForOID(pgtype.TextOID)
    180 					if ok {
    181 						arg = t
    182 					}
    183 				}
    184 			}
    185 			if !ok {
    186 				var dv driver.Valuer
    187 				if dv, ok = arg.(driver.Valuer); ok {
    188 					v, err := dv.Value()
    189 					if err != nil {
    190 						return err
    191 					}
    192 					dt, ok = m.TypeForValue(v)
    193 					if ok {
    194 						arg = v
    195 					}
    196 				}
    197 			}
    198 			if !ok {
    199 				var str fmt.Stringer
    200 				if str, ok = arg.(fmt.Stringer); ok {
    201 					dt, ok = m.TypeForOID(pgtype.TextOID)
    202 					if ok {
    203 						arg = str.String()
    204 					}
    205 				}
    206 			}
    207 			if !ok {
    208 				return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
    209 			}
    210 			err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
    211 			if err != nil {
    212 				return err
    213 			}
    214 		}
    215 	}
    216 
    217 	return nil
    218 }