gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

named_args.go (5795B)


      1 package pgx
      2 
      3 import (
      4 	"context"
      5 	"strconv"
      6 	"strings"
      7 	"unicode/utf8"
      8 )
      9 
     10 // NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
     11 // ordinal placeholder and construct the appropriate arguments.
     12 //
     13 // For example, the following two queries are equivalent:
     14 //
     15 //	conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
     16 //	conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
     17 type NamedArgs map[string]any
     18 
     19 // RewriteQuery implements the QueryRewriter interface.
     20 func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
     21 	l := &sqlLexer{
     22 		src:           sql,
     23 		stateFn:       rawState,
     24 		nameToOrdinal: make(map[namedArg]int, len(na)),
     25 	}
     26 
     27 	for l.stateFn != nil {
     28 		l.stateFn = l.stateFn(l)
     29 	}
     30 
     31 	sb := strings.Builder{}
     32 	for _, p := range l.parts {
     33 		switch p := p.(type) {
     34 		case string:
     35 			sb.WriteString(p)
     36 		case namedArg:
     37 			sb.WriteRune('$')
     38 			sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
     39 		}
     40 	}
     41 
     42 	newArgs = make([]any, len(l.nameToOrdinal))
     43 	for name, ordinal := range l.nameToOrdinal {
     44 		newArgs[ordinal-1] = na[string(name)]
     45 	}
     46 
     47 	return sb.String(), newArgs, nil
     48 }
     49 
// namedArg is the name of an @-placeholder, stored without the leading '@'. It is
// appended to sqlLexer.parts wherever a placeholder occurred in the source SQL.
type namedArg string

// sqlLexer holds the state of one scan of a SQL string looking for named placeholders.
// Positions are byte offsets into src.
type sqlLexer struct {
	src     string  // SQL text being scanned.
	start   int     // offset of the first byte not yet emitted into parts.
	pos     int     // offset of the next rune to read.
	nested  int     // multiline comment nesting level.
	stateFn stateFn // state to run next; nil once scanning has finished.
	parts   []any   // output: string for verbatim SQL text, namedArg for a placeholder.

	// nameToOrdinal assigns each distinct placeholder name a 1-based ordinal,
	// in order of first appearance.
	nameToOrdinal map[namedArg]int
}

// stateFn is one state of the lexer. It consumes input from the lexer and returns
// the next state to run, or nil when lexing is complete.
type stateFn func(*sqlLexer) stateFn
     64 
     65 func rawState(l *sqlLexer) stateFn {
     66 	for {
     67 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
     68 		l.pos += width
     69 
     70 		switch r {
     71 		case 'e', 'E':
     72 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
     73 			if nextRune == '\'' {
     74 				l.pos += width
     75 				return escapeStringState
     76 			}
     77 		case '\'':
     78 			return singleQuoteState
     79 		case '"':
     80 			return doubleQuoteState
     81 		case '@':
     82 			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
     83 			if isLetter(nextRune) {
     84 				if l.pos-l.start > 0 {
     85 					l.parts = append(l.parts, l.src[l.start:l.pos-width])
     86 				}
     87 				l.start = l.pos
     88 				return namedArgState
     89 			}
     90 		case '-':
     91 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
     92 			if nextRune == '-' {
     93 				l.pos += width
     94 				return oneLineCommentState
     95 			}
     96 		case '/':
     97 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
     98 			if nextRune == '*' {
     99 				l.pos += width
    100 				return multilineCommentState
    101 			}
    102 		case utf8.RuneError:
    103 			if l.pos-l.start > 0 {
    104 				l.parts = append(l.parts, l.src[l.start:l.pos])
    105 				l.start = l.pos
    106 			}
    107 			return nil
    108 		}
    109 	}
    110 }
    111 
    112 func isLetter(r rune) bool {
    113 	return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
    114 }
    115 
    116 func namedArgState(l *sqlLexer) stateFn {
    117 	for {
    118 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    119 		l.pos += width
    120 
    121 		if r == utf8.RuneError {
    122 			if l.pos-l.start > 0 {
    123 				na := namedArg(l.src[l.start:l.pos])
    124 				if _, found := l.nameToOrdinal[na]; !found {
    125 					l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
    126 				}
    127 				l.parts = append(l.parts, na)
    128 				l.start = l.pos
    129 			}
    130 			return nil
    131 		} else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
    132 			l.pos -= width
    133 			na := namedArg(l.src[l.start:l.pos])
    134 			if _, found := l.nameToOrdinal[na]; !found {
    135 				l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
    136 			}
    137 			l.parts = append(l.parts, namedArg(na))
    138 			l.start = l.pos
    139 			return rawState
    140 		}
    141 	}
    142 }
    143 
    144 func singleQuoteState(l *sqlLexer) stateFn {
    145 	for {
    146 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    147 		l.pos += width
    148 
    149 		switch r {
    150 		case '\'':
    151 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    152 			if nextRune != '\'' {
    153 				return rawState
    154 			}
    155 			l.pos += width
    156 		case utf8.RuneError:
    157 			if l.pos-l.start > 0 {
    158 				l.parts = append(l.parts, l.src[l.start:l.pos])
    159 				l.start = l.pos
    160 			}
    161 			return nil
    162 		}
    163 	}
    164 }
    165 
    166 func doubleQuoteState(l *sqlLexer) stateFn {
    167 	for {
    168 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    169 		l.pos += width
    170 
    171 		switch r {
    172 		case '"':
    173 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    174 			if nextRune != '"' {
    175 				return rawState
    176 			}
    177 			l.pos += width
    178 		case utf8.RuneError:
    179 			if l.pos-l.start > 0 {
    180 				l.parts = append(l.parts, l.src[l.start:l.pos])
    181 				l.start = l.pos
    182 			}
    183 			return nil
    184 		}
    185 	}
    186 }
    187 
    188 func escapeStringState(l *sqlLexer) stateFn {
    189 	for {
    190 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    191 		l.pos += width
    192 
    193 		switch r {
    194 		case '\\':
    195 			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
    196 			l.pos += width
    197 		case '\'':
    198 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    199 			if nextRune != '\'' {
    200 				return rawState
    201 			}
    202 			l.pos += width
    203 		case utf8.RuneError:
    204 			if l.pos-l.start > 0 {
    205 				l.parts = append(l.parts, l.src[l.start:l.pos])
    206 				l.start = l.pos
    207 			}
    208 			return nil
    209 		}
    210 	}
    211 }
    212 
    213 func oneLineCommentState(l *sqlLexer) stateFn {
    214 	for {
    215 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    216 		l.pos += width
    217 
    218 		switch r {
    219 		case '\\':
    220 			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
    221 			l.pos += width
    222 		case '\n', '\r':
    223 			return rawState
    224 		case utf8.RuneError:
    225 			if l.pos-l.start > 0 {
    226 				l.parts = append(l.parts, l.src[l.start:l.pos])
    227 				l.start = l.pos
    228 			}
    229 			return nil
    230 		}
    231 	}
    232 }
    233 
    234 func multilineCommentState(l *sqlLexer) stateFn {
    235 	for {
    236 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    237 		l.pos += width
    238 
    239 		switch r {
    240 		case '/':
    241 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    242 			if nextRune == '*' {
    243 				l.pos += width
    244 				l.nested++
    245 			}
    246 		case '*':
    247 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    248 			if nextRune != '/' {
    249 				continue
    250 			}
    251 
    252 			l.pos += width
    253 			if l.nested == 0 {
    254 				return rawState
    255 			}
    256 			l.nested--
    257 
    258 		case utf8.RuneError:
    259 			if l.pos-l.start > 0 {
    260 				l.parts = append(l.parts, l.src[l.start:l.pos])
    261 				l.start = l.pos
    262 			}
    263 			return nil
    264 		}
    265 	}
    266 }