named_args.go (5795B)
1 package pgx 2 3 import ( 4 "context" 5 "strconv" 6 "strings" 7 "unicode/utf8" 8 ) 9 10 // NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$' 11 // ordinal placeholder and construct the appropriate arguments. 12 // 13 // For example, the following two queries are equivalent: 14 // 15 // conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) 16 // conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) 17 type NamedArgs map[string]any 18 19 // RewriteQuery implements the QueryRewriter interface. 20 func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { 21 l := &sqlLexer{ 22 src: sql, 23 stateFn: rawState, 24 nameToOrdinal: make(map[namedArg]int, len(na)), 25 } 26 27 for l.stateFn != nil { 28 l.stateFn = l.stateFn(l) 29 } 30 31 sb := strings.Builder{} 32 for _, p := range l.parts { 33 switch p := p.(type) { 34 case string: 35 sb.WriteString(p) 36 case namedArg: 37 sb.WriteRune('$') 38 sb.WriteString(strconv.Itoa(l.nameToOrdinal[p])) 39 } 40 } 41 42 newArgs = make([]any, len(l.nameToOrdinal)) 43 for name, ordinal := range l.nameToOrdinal { 44 newArgs[ordinal-1] = na[string(name)] 45 } 46 47 return sb.String(), newArgs, nil 48 } 49 50 type namedArg string 51 52 type sqlLexer struct { 53 src string 54 start int 55 pos int 56 nested int // multiline comment nesting level. 57 stateFn stateFn 58 parts []any 59 60 nameToOrdinal map[namedArg]int 61 } 62 63 type stateFn func(*sqlLexer) stateFn 64 65 func rawState(l *sqlLexer) stateFn { 66 for { 67 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 68 l.pos += width 69 70 switch r { 71 case 'e', 'E': 72 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 73 if nextRune == '\'' { 74 l.pos += width 75 return escapeStringState 76 } 77 case '\'': 78 return singleQuoteState 79 case '"': 80 return doubleQuoteState 81 case '@': 82 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) 83 if isLetter(nextRune) { 84 if l.pos-l.start > 0 { 85 l.parts = append(l.parts, l.src[l.start:l.pos-width]) 86 } 87 l.start = l.pos 88 return namedArgState 89 } 90 case '-': 91 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 92 if nextRune == '-' { 93 l.pos += width 94 return oneLineCommentState 95 } 96 case '/': 97 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 98 if nextRune == '*' { 99 l.pos += width 100 return multilineCommentState 101 } 102 case utf8.RuneError: 103 if l.pos-l.start > 0 { 104 l.parts = append(l.parts, l.src[l.start:l.pos]) 105 l.start = l.pos 106 } 107 return nil 108 } 109 } 110 } 111 112 func isLetter(r rune) bool { 113 return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') 114 } 115 116 func namedArgState(l *sqlLexer) stateFn { 117 for { 118 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 119 l.pos += width 120 121 if r == utf8.RuneError { 122 if l.pos-l.start > 0 { 123 na := namedArg(l.src[l.start:l.pos]) 124 if _, found := l.nameToOrdinal[na]; !found { 125 l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 126 } 127 l.parts = append(l.parts, na) 128 l.start = l.pos 129 } 130 return nil 131 } else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') { 132 l.pos -= width 133 na := namedArg(l.src[l.start:l.pos]) 134 if _, found := l.nameToOrdinal[na]; !found { 135 l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 136 } 137 l.parts = append(l.parts, namedArg(na)) 138 l.start = l.pos 139 return rawState 140 } 141 } 142 } 143 144 func singleQuoteState(l *sqlLexer) stateFn { 145 for { 146 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 147 l.pos += width 148 149 switch r { 150 case '\'': 151 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 152 if nextRune != '\'' { 153 return rawState 154 } 155 l.pos += width 156 case utf8.RuneError: 157 if l.pos-l.start > 0 { 158 l.parts = append(l.parts, l.src[l.start:l.pos]) 159 l.start = l.pos 160 } 161 return nil 162 } 163 } 164 } 165 166 func doubleQuoteState(l *sqlLexer) stateFn { 167 for { 168 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 169 l.pos += width 170 171 switch r { 172 case '"': 173 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 174 if nextRune != '"' { 175 return rawState 176 } 177 l.pos += width 178 case utf8.RuneError: 179 if l.pos-l.start > 0 { 180 l.parts = append(l.parts, l.src[l.start:l.pos]) 181 l.start = l.pos 182 } 183 return nil 184 } 185 } 186 } 187 188 func escapeStringState(l *sqlLexer) stateFn { 189 for { 190 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 191 l.pos += width 192 193 switch r { 194 case '\\': 195 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 196 l.pos += width 197 case '\'': 198 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 199 if nextRune != '\'' { 200 return rawState 201 } 202 l.pos += width 203 case utf8.RuneError: 204 if l.pos-l.start > 0 { 205 l.parts = append(l.parts, l.src[l.start:l.pos]) 206 l.start = l.pos 207 } 208 return nil 209 } 210 } 211 } 212 213 func oneLineCommentState(l *sqlLexer) stateFn { 214 for { 215 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 216 l.pos += width 217 218 switch r { 219 case '\\': 220 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 221 l.pos += width 222 case '\n', '\r': 223 return rawState 224 case utf8.RuneError: 225 if l.pos-l.start > 0 { 226 l.parts = append(l.parts, l.src[l.start:l.pos]) 227 l.start = l.pos 228 } 229 return nil 230 } 231 } 232 } 233 234 func multilineCommentState(l *sqlLexer) stateFn { 235 for { 236 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 237 l.pos += width 238 239 switch r { 240 case '/': 241 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 242 if nextRune == '*' { 243 l.pos += width 244 l.nested++ 245 } 246 case '*': 247 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 248 if nextRune != '/' { 249 continue 250 } 251 252 l.pos += width 253 if l.nested == 0 { 254 return rawState 255 } 256 l.nested-- 257 258 case utf8.RuneError: 259 if l.pos-l.start > 0 { 260 l.parts = append(l.parts, l.src[l.start:l.pos]) 261 l.start = l.pos 262 } 263 return nil 264 } 265 } 266 }