sanitize.go (6790B)
1 package sanitize 2 3 import ( 4 "bytes" 5 "encoding/hex" 6 "fmt" 7 "strconv" 8 "strings" 9 "time" 10 "unicode/utf8" 11 ) 12 13 // Part is either a string or an int. A string is raw SQL. An int is a 14 // argument placeholder. 15 type Part any 16 17 type Query struct { 18 Parts []Part 19 } 20 21 // utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement 22 // character. utf8.RuneError is not an error if it is also width 3. 23 // 24 // https://github.com/jackc/pgx/issues/1380 25 const replacementcharacterwidth = 3 26 27 func (q *Query) Sanitize(args ...any) (string, error) { 28 argUse := make([]bool, len(args)) 29 buf := &bytes.Buffer{} 30 31 for _, part := range q.Parts { 32 var str string 33 switch part := part.(type) { 34 case string: 35 str = part 36 case int: 37 argIdx := part - 1 38 if argIdx >= len(args) { 39 return "", fmt.Errorf("insufficient arguments") 40 } 41 arg := args[argIdx] 42 switch arg := arg.(type) { 43 case nil: 44 str = "null" 45 case int64: 46 str = strconv.FormatInt(arg, 10) 47 case float64: 48 str = strconv.FormatFloat(arg, 'f', -1, 64) 49 case bool: 50 str = strconv.FormatBool(arg) 51 case []byte: 52 str = QuoteBytes(arg) 53 case string: 54 str = QuoteString(arg) 55 case time.Time: 56 str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") 57 default: 58 return "", fmt.Errorf("invalid arg type: %T", arg) 59 } 60 argUse[argIdx] = true 61 default: 62 return "", fmt.Errorf("invalid Part type: %T", part) 63 } 64 buf.WriteString(str) 65 } 66 67 for i, used := range argUse { 68 if !used { 69 return "", fmt.Errorf("unused argument: %d", i) 70 } 71 } 72 return buf.String(), nil 73 } 74 75 func NewQuery(sql string) (*Query, error) { 76 l := &sqlLexer{ 77 src: sql, 78 stateFn: rawState, 79 } 80 81 for l.stateFn != nil { 82 l.stateFn = l.stateFn(l) 83 } 84 85 query := &Query{Parts: l.parts} 86 87 return query, nil 88 } 89 90 func QuoteString(str string) string { 91 return "'" + strings.ReplaceAll(str, "'", "''") + "'" 92 } 93 94 func QuoteBytes(buf []byte) string { 95 return `'\x` + hex.EncodeToString(buf) + "'" 96 } 97 98 type sqlLexer struct { 99 src string 100 start int 101 pos int 102 nested int // multiline comment nesting level. 103 stateFn stateFn 104 parts []Part 105 } 106 107 type stateFn func(*sqlLexer) stateFn 108 109 func rawState(l *sqlLexer) stateFn { 110 for { 111 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 112 l.pos += width 113 114 switch r { 115 case 'e', 'E': 116 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 117 if nextRune == '\'' { 118 l.pos += width 119 return escapeStringState 120 } 121 case '\'': 122 return singleQuoteState 123 case '"': 124 return doubleQuoteState 125 case '$': 126 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) 127 if '0' <= nextRune && nextRune <= '9' { 128 if l.pos-l.start > 0 { 129 l.parts = append(l.parts, l.src[l.start:l.pos-width]) 130 } 131 l.start = l.pos 132 return placeholderState 133 } 134 case '-': 135 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 136 if nextRune == '-' { 137 l.pos += width 138 return oneLineCommentState 139 } 140 case '/': 141 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 142 if nextRune == '*' { 143 l.pos += width 144 return multilineCommentState 145 } 146 case utf8.RuneError: 147 if width != replacementcharacterwidth { 148 if l.pos-l.start > 0 { 149 l.parts = append(l.parts, l.src[l.start:l.pos]) 150 l.start = l.pos 151 } 152 return nil 153 } 154 } 155 } 156 } 157 158 func singleQuoteState(l *sqlLexer) stateFn { 159 for { 160 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 161 l.pos += width 162 163 switch r { 164 case '\'': 165 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 166 if nextRune != '\'' { 167 return rawState 168 } 169 l.pos += width 170 case utf8.RuneError: 171 if width != replacementcharacterwidth { 172 if l.pos-l.start > 0 { 173 l.parts = append(l.parts, l.src[l.start:l.pos]) 174 l.start = l.pos 175 } 176 return nil 177 } 178 } 179 } 180 } 181 182 func doubleQuoteState(l *sqlLexer) stateFn { 183 for { 184 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 185 l.pos += width 186 187 switch r { 188 case '"': 189 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 190 if nextRune != '"' { 191 return rawState 192 } 193 l.pos += width 194 case utf8.RuneError: 195 if width != replacementcharacterwidth { 196 if l.pos-l.start > 0 { 197 l.parts = append(l.parts, l.src[l.start:l.pos]) 198 l.start = l.pos 199 } 200 return nil 201 } 202 } 203 } 204 } 205 206 // placeholderState consumes a placeholder value. The $ must have already has 207 // already been consumed. The first rune must be a digit. 208 func placeholderState(l *sqlLexer) stateFn { 209 num := 0 210 211 for { 212 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 213 l.pos += width 214 215 if '0' <= r && r <= '9' { 216 num *= 10 217 num += int(r - '0') 218 } else { 219 l.parts = append(l.parts, num) 220 l.pos -= width 221 l.start = l.pos 222 return rawState 223 } 224 } 225 } 226 227 func escapeStringState(l *sqlLexer) stateFn { 228 for { 229 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 230 l.pos += width 231 232 switch r { 233 case '\\': 234 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 235 l.pos += width 236 case '\'': 237 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 238 if nextRune != '\'' { 239 return rawState 240 } 241 l.pos += width 242 case utf8.RuneError: 243 if width != replacementcharacterwidth { 244 if l.pos-l.start > 0 { 245 l.parts = append(l.parts, l.src[l.start:l.pos]) 246 l.start = l.pos 247 } 248 return nil 249 } 250 } 251 } 252 } 253 254 func oneLineCommentState(l *sqlLexer) stateFn { 255 for { 256 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 257 l.pos += width 258 259 switch r { 260 case '\\': 261 _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 262 l.pos += width 263 case '\n', '\r': 264 return rawState 265 case utf8.RuneError: 266 if width != replacementcharacterwidth { 267 if l.pos-l.start > 0 { 268 l.parts = append(l.parts, l.src[l.start:l.pos]) 269 l.start = l.pos 270 } 271 return nil 272 } 273 } 274 } 275 } 276 277 func multilineCommentState(l *sqlLexer) stateFn { 278 for { 279 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 280 l.pos += width 281 282 switch r { 283 case '/': 284 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 285 if nextRune == '*' { 286 l.pos += width 287 l.nested++ 288 } 289 case '*': 290 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 291 if nextRune != '/' { 292 continue 293 } 294 295 l.pos += width 296 if l.nested == 0 { 297 return rawState 298 } 299 l.nested-- 300 301 case utf8.RuneError: 302 if width != replacementcharacterwidth { 303 if l.pos-l.start > 0 { 304 l.parts = append(l.parts, l.src[l.start:l.pos]) 305 l.start = l.pos 306 } 307 return nil 308 } 309 } 310 } 311 } 312 313 // SanitizeSQL replaces placeholder values with args. It quotes and escapes args 314 // as necessary. This function is only safe when standard_conforming_strings is 315 // on. 316 func SanitizeSQL(sql string, args ...any) (string, error) { 317 query, err := NewQuery(sql) 318 if err != nil { 319 return "", err 320 } 321 return query.Sanitize(args...) 322 }