gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

sanitize.go (6790B)


      1 package sanitize
      2 
      3 import (
      4 	"bytes"
      5 	"encoding/hex"
      6 	"fmt"
      7 	"strconv"
      8 	"strings"
      9 	"time"
     10 	"unicode/utf8"
     11 )
     12 
// Part is one element of a parsed query. It holds either a string or an int.
// A string is raw SQL that is written to the output unchanged. An int is the
// 1-based number of an argument placeholder ($1 is 1, $2 is 2, ...).
type Part any
     16 
// Query is a SQL statement that has been split into raw SQL fragments and
// argument placeholders.
type Query struct {
	// Parts holds the fragments and placeholder numbers in the order they
	// are written to the output.
	Parts []Part
}
     20 
// utf8.DecodeRuneInString returns utf8.RuneError when it cannot decode a rune
// (empty input reports width 0, an invalid encoding reports width 1). But
// utf8.RuneError is also the rune U+FFFD -- the unicode replacement character
// -- which is a legitimate character encoded in 3 bytes. A utf8.RuneError
// result is therefore not an error when its width is 3.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
     26 
     27 func (q *Query) Sanitize(args ...any) (string, error) {
     28 	argUse := make([]bool, len(args))
     29 	buf := &bytes.Buffer{}
     30 
     31 	for _, part := range q.Parts {
     32 		var str string
     33 		switch part := part.(type) {
     34 		case string:
     35 			str = part
     36 		case int:
     37 			argIdx := part - 1
     38 			if argIdx >= len(args) {
     39 				return "", fmt.Errorf("insufficient arguments")
     40 			}
     41 			arg := args[argIdx]
     42 			switch arg := arg.(type) {
     43 			case nil:
     44 				str = "null"
     45 			case int64:
     46 				str = strconv.FormatInt(arg, 10)
     47 			case float64:
     48 				str = strconv.FormatFloat(arg, 'f', -1, 64)
     49 			case bool:
     50 				str = strconv.FormatBool(arg)
     51 			case []byte:
     52 				str = QuoteBytes(arg)
     53 			case string:
     54 				str = QuoteString(arg)
     55 			case time.Time:
     56 				str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
     57 			default:
     58 				return "", fmt.Errorf("invalid arg type: %T", arg)
     59 			}
     60 			argUse[argIdx] = true
     61 		default:
     62 			return "", fmt.Errorf("invalid Part type: %T", part)
     63 		}
     64 		buf.WriteString(str)
     65 	}
     66 
     67 	for i, used := range argUse {
     68 		if !used {
     69 			return "", fmt.Errorf("unused argument: %d", i)
     70 		}
     71 	}
     72 	return buf.String(), nil
     73 }
     74 
     75 func NewQuery(sql string) (*Query, error) {
     76 	l := &sqlLexer{
     77 		src:     sql,
     78 		stateFn: rawState,
     79 	}
     80 
     81 	for l.stateFn != nil {
     82 		l.stateFn = l.stateFn(l)
     83 	}
     84 
     85 	query := &Query{Parts: l.parts}
     86 
     87 	return query, nil
     88 }
     89 
     90 func QuoteString(str string) string {
     91 	return "'" + strings.ReplaceAll(str, "'", "''") + "'"
     92 }
     93 
     94 func QuoteBytes(buf []byte) string {
     95 	return `'\x` + hex.EncodeToString(buf) + "'"
     96 }
     97 
// sqlLexer holds the scanning state used to split a SQL string into Parts.
type sqlLexer struct {
	// src is the SQL text being scanned.
	src     string
	// start is the byte offset in src where the raw SQL that has not yet
	// been appended to parts begins.
	start   int
	// pos is the byte offset in src of the next rune to decode.
	pos     int
	// nested is the multiline comment nesting level.
	nested  int
	// stateFn is the state function to run next; the lexer is finished
	// once it is nil.
	stateFn stateFn
	// parts collects the raw SQL fragments (string) and placeholder
	// numbers (int) found so far.
	parts   []Part
}
    106 
// stateFn is one state of the lexer. It consumes input from the lexer it is
// given and returns the state to run next, or nil when lexing is finished.
type stateFn func(*sqlLexer) stateFn
    108 
    109 func rawState(l *sqlLexer) stateFn {
    110 	for {
    111 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    112 		l.pos += width
    113 
    114 		switch r {
    115 		case 'e', 'E':
    116 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    117 			if nextRune == '\'' {
    118 				l.pos += width
    119 				return escapeStringState
    120 			}
    121 		case '\'':
    122 			return singleQuoteState
    123 		case '"':
    124 			return doubleQuoteState
    125 		case '$':
    126 			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
    127 			if '0' <= nextRune && nextRune <= '9' {
    128 				if l.pos-l.start > 0 {
    129 					l.parts = append(l.parts, l.src[l.start:l.pos-width])
    130 				}
    131 				l.start = l.pos
    132 				return placeholderState
    133 			}
    134 		case '-':
    135 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    136 			if nextRune == '-' {
    137 				l.pos += width
    138 				return oneLineCommentState
    139 			}
    140 		case '/':
    141 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    142 			if nextRune == '*' {
    143 				l.pos += width
    144 				return multilineCommentState
    145 			}
    146 		case utf8.RuneError:
    147 			if width != replacementcharacterwidth {
    148 				if l.pos-l.start > 0 {
    149 					l.parts = append(l.parts, l.src[l.start:l.pos])
    150 					l.start = l.pos
    151 				}
    152 				return nil
    153 			}
    154 		}
    155 	}
    156 }
    157 
    158 func singleQuoteState(l *sqlLexer) stateFn {
    159 	for {
    160 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    161 		l.pos += width
    162 
    163 		switch r {
    164 		case '\'':
    165 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    166 			if nextRune != '\'' {
    167 				return rawState
    168 			}
    169 			l.pos += width
    170 		case utf8.RuneError:
    171 			if width != replacementcharacterwidth {
    172 				if l.pos-l.start > 0 {
    173 					l.parts = append(l.parts, l.src[l.start:l.pos])
    174 					l.start = l.pos
    175 				}
    176 				return nil
    177 			}
    178 		}
    179 	}
    180 }
    181 
    182 func doubleQuoteState(l *sqlLexer) stateFn {
    183 	for {
    184 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    185 		l.pos += width
    186 
    187 		switch r {
    188 		case '"':
    189 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    190 			if nextRune != '"' {
    191 				return rawState
    192 			}
    193 			l.pos += width
    194 		case utf8.RuneError:
    195 			if width != replacementcharacterwidth {
    196 				if l.pos-l.start > 0 {
    197 					l.parts = append(l.parts, l.src[l.start:l.pos])
    198 					l.start = l.pos
    199 				}
    200 				return nil
    201 			}
    202 		}
    203 	}
    204 }
    205 
    206 // placeholderState consumes a placeholder value. The $ must have already has
    207 // already been consumed. The first rune must be a digit.
    208 func placeholderState(l *sqlLexer) stateFn {
    209 	num := 0
    210 
    211 	for {
    212 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    213 		l.pos += width
    214 
    215 		if '0' <= r && r <= '9' {
    216 			num *= 10
    217 			num += int(r - '0')
    218 		} else {
    219 			l.parts = append(l.parts, num)
    220 			l.pos -= width
    221 			l.start = l.pos
    222 			return rawState
    223 		}
    224 	}
    225 }
    226 
    227 func escapeStringState(l *sqlLexer) stateFn {
    228 	for {
    229 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    230 		l.pos += width
    231 
    232 		switch r {
    233 		case '\\':
    234 			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
    235 			l.pos += width
    236 		case '\'':
    237 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    238 			if nextRune != '\'' {
    239 				return rawState
    240 			}
    241 			l.pos += width
    242 		case utf8.RuneError:
    243 			if width != replacementcharacterwidth {
    244 				if l.pos-l.start > 0 {
    245 					l.parts = append(l.parts, l.src[l.start:l.pos])
    246 					l.start = l.pos
    247 				}
    248 				return nil
    249 			}
    250 		}
    251 	}
    252 }
    253 
    254 func oneLineCommentState(l *sqlLexer) stateFn {
    255 	for {
    256 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    257 		l.pos += width
    258 
    259 		switch r {
    260 		case '\\':
    261 			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
    262 			l.pos += width
    263 		case '\n', '\r':
    264 			return rawState
    265 		case utf8.RuneError:
    266 			if width != replacementcharacterwidth {
    267 				if l.pos-l.start > 0 {
    268 					l.parts = append(l.parts, l.src[l.start:l.pos])
    269 					l.start = l.pos
    270 				}
    271 				return nil
    272 			}
    273 		}
    274 	}
    275 }
    276 
    277 func multilineCommentState(l *sqlLexer) stateFn {
    278 	for {
    279 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    280 		l.pos += width
    281 
    282 		switch r {
    283 		case '/':
    284 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    285 			if nextRune == '*' {
    286 				l.pos += width
    287 				l.nested++
    288 			}
    289 		case '*':
    290 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    291 			if nextRune != '/' {
    292 				continue
    293 			}
    294 
    295 			l.pos += width
    296 			if l.nested == 0 {
    297 				return rawState
    298 			}
    299 			l.nested--
    300 
    301 		case utf8.RuneError:
    302 			if width != replacementcharacterwidth {
    303 				if l.pos-l.start > 0 {
    304 					l.parts = append(l.parts, l.src[l.start:l.pos])
    305 					l.start = l.pos
    306 				}
    307 				return nil
    308 			}
    309 		}
    310 	}
    311 }
    312 
    313 // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
    314 // as necessary. This function is only safe when standard_conforming_strings is
    315 // on.
    316 func SanitizeSQL(sql string, args ...any) (string, error) {
    317 	query, err := NewQuery(sql)
    318 	if err != nil {
    319 		return "", err
    320 	}
    321 	return query.Sanitize(args...)
    322 }