scan.go (10140B)
1 package schema 2 3 import ( 4 "bytes" 5 "database/sql" 6 "fmt" 7 "net" 8 "reflect" 9 "strconv" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/vmihailenco/msgpack/v5" 15 16 "github.com/uptrace/bun/dialect/sqltype" 17 "github.com/uptrace/bun/extra/bunjson" 18 "github.com/uptrace/bun/internal" 19 ) 20 21 var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() 22 23 type ScannerFunc func(dest reflect.Value, src interface{}) error 24 25 var scanners []ScannerFunc 26 27 func init() { 28 scanners = []ScannerFunc{ 29 reflect.Bool: scanBool, 30 reflect.Int: scanInt64, 31 reflect.Int8: scanInt64, 32 reflect.Int16: scanInt64, 33 reflect.Int32: scanInt64, 34 reflect.Int64: scanInt64, 35 reflect.Uint: scanUint64, 36 reflect.Uint8: scanUint64, 37 reflect.Uint16: scanUint64, 38 reflect.Uint32: scanUint64, 39 reflect.Uint64: scanUint64, 40 reflect.Uintptr: scanUint64, 41 reflect.Float32: scanFloat64, 42 reflect.Float64: scanFloat64, 43 reflect.Complex64: nil, 44 reflect.Complex128: nil, 45 reflect.Array: nil, 46 reflect.Interface: scanInterface, 47 reflect.Map: scanJSON, 48 reflect.Ptr: nil, 49 reflect.Slice: scanJSON, 50 reflect.String: scanString, 51 reflect.Struct: scanJSON, 52 reflect.UnsafePointer: nil, 53 } 54 } 55 56 var scannerMap sync.Map 57 58 func FieldScanner(dialect Dialect, field *Field) ScannerFunc { 59 if field.Tag.HasOption("msgpack") { 60 return scanMsgpack 61 } 62 if field.Tag.HasOption("json_use_number") { 63 return scanJSONUseNumber 64 } 65 if field.StructField.Type.Kind() == reflect.Interface { 66 switch strings.ToUpper(field.UserSQLType) { 67 case sqltype.JSON, sqltype.JSONB: 68 return scanJSONIntoInterface 69 } 70 } 71 return Scanner(field.StructField.Type) 72 } 73 74 func Scanner(typ reflect.Type) ScannerFunc { 75 if v, ok := scannerMap.Load(typ); ok { 76 return v.(ScannerFunc) 77 } 78 79 fn := scanner(typ) 80 81 if v, ok := scannerMap.LoadOrStore(typ, fn); ok { 82 return v.(ScannerFunc) 83 } 84 return fn 85 } 86 87 func scanner(typ reflect.Type) ScannerFunc { 88 kind := typ.Kind() 89 90 if kind == reflect.Ptr { 91 if fn := Scanner(typ.Elem()); fn != nil { 92 return PtrScanner(fn) 93 } 94 } 95 96 switch typ { 97 case bytesType: 98 return scanBytes 99 case timeType: 100 return scanTime 101 case ipType: 102 return scanIP 103 case ipNetType: 104 return scanIPNet 105 case jsonRawMessageType: 106 return scanBytes 107 } 108 109 if typ.Implements(scannerType) { 110 return scanScanner 111 } 112 113 if kind != reflect.Ptr { 114 ptr := reflect.PtrTo(typ) 115 if ptr.Implements(scannerType) { 116 return addrScanner(scanScanner) 117 } 118 } 119 120 if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 { 121 return scanBytes 122 } 123 124 return scanners[kind] 125 } 126 127 func scanBool(dest reflect.Value, src interface{}) error { 128 switch src := src.(type) { 129 case nil: 130 dest.SetBool(false) 131 return nil 132 case bool: 133 dest.SetBool(src) 134 return nil 135 case int64: 136 dest.SetBool(src != 0) 137 return nil 138 case []byte: 139 f, err := strconv.ParseBool(internal.String(src)) 140 if err != nil { 141 return err 142 } 143 dest.SetBool(f) 144 return nil 145 case string: 146 f, err := strconv.ParseBool(src) 147 if err != nil { 148 return err 149 } 150 dest.SetBool(f) 151 return nil 152 default: 153 return scanError(dest.Type(), src) 154 } 155 } 156 157 func scanInt64(dest reflect.Value, src interface{}) error { 158 switch src := src.(type) { 159 case nil: 160 dest.SetInt(0) 161 return nil 162 case int64: 163 dest.SetInt(src) 164 return nil 165 case uint64: 166 dest.SetInt(int64(src)) 167 return nil 168 case []byte: 169 n, err := strconv.ParseInt(internal.String(src), 10, 64) 170 if err != nil { 171 return err 172 } 173 dest.SetInt(n) 174 return nil 175 case string: 176 n, err := strconv.ParseInt(src, 10, 64) 177 if err != nil { 178 return err 179 } 180 dest.SetInt(n) 181 return nil 182 default: 183 return scanError(dest.Type(), src) 184 } 185 } 186 187 func scanUint64(dest reflect.Value, src interface{}) error { 188 switch src := src.(type) { 189 case nil: 190 dest.SetUint(0) 191 return nil 192 case uint64: 193 dest.SetUint(src) 194 return nil 195 case int64: 196 dest.SetUint(uint64(src)) 197 return nil 198 case []byte: 199 n, err := strconv.ParseUint(internal.String(src), 10, 64) 200 if err != nil { 201 return err 202 } 203 dest.SetUint(n) 204 return nil 205 case string: 206 n, err := strconv.ParseUint(src, 10, 64) 207 if err != nil { 208 return err 209 } 210 dest.SetUint(n) 211 return nil 212 default: 213 return scanError(dest.Type(), src) 214 } 215 } 216 217 func scanFloat64(dest reflect.Value, src interface{}) error { 218 switch src := src.(type) { 219 case nil: 220 dest.SetFloat(0) 221 return nil 222 case float64: 223 dest.SetFloat(src) 224 return nil 225 case []byte: 226 f, err := strconv.ParseFloat(internal.String(src), 64) 227 if err != nil { 228 return err 229 } 230 dest.SetFloat(f) 231 return nil 232 case string: 233 f, err := strconv.ParseFloat(src, 64) 234 if err != nil { 235 return err 236 } 237 dest.SetFloat(f) 238 return nil 239 default: 240 return scanError(dest.Type(), src) 241 } 242 } 243 244 func scanString(dest reflect.Value, src interface{}) error { 245 switch src := src.(type) { 246 case nil: 247 dest.SetString("") 248 return nil 249 case string: 250 dest.SetString(src) 251 return nil 252 case []byte: 253 dest.SetString(string(src)) 254 return nil 255 case time.Time: 256 dest.SetString(src.Format(time.RFC3339Nano)) 257 return nil 258 case int64: 259 dest.SetString(strconv.FormatInt(src, 10)) 260 return nil 261 case uint64: 262 dest.SetString(strconv.FormatUint(src, 10)) 263 return nil 264 case float64: 265 dest.SetString(strconv.FormatFloat(src, 'G', -1, 64)) 266 return nil 267 default: 268 return scanError(dest.Type(), src) 269 } 270 } 271 272 func scanBytes(dest reflect.Value, src interface{}) error { 273 switch src := src.(type) { 274 case nil: 275 dest.SetBytes(nil) 276 return nil 277 case string: 278 dest.SetBytes([]byte(src)) 279 return nil 280 case []byte: 281 clone := make([]byte, len(src)) 282 copy(clone, src) 283 284 dest.SetBytes(clone) 285 return nil 286 default: 287 return scanError(dest.Type(), src) 288 } 289 } 290 291 func scanTime(dest reflect.Value, src interface{}) error { 292 switch src := src.(type) { 293 case nil: 294 destTime := dest.Addr().Interface().(*time.Time) 295 *destTime = time.Time{} 296 return nil 297 case time.Time: 298 destTime := dest.Addr().Interface().(*time.Time) 299 *destTime = src 300 return nil 301 case string: 302 srcTime, err := internal.ParseTime(src) 303 if err != nil { 304 return err 305 } 306 destTime := dest.Addr().Interface().(*time.Time) 307 *destTime = srcTime 308 return nil 309 case []byte: 310 srcTime, err := internal.ParseTime(internal.String(src)) 311 if err != nil { 312 return err 313 } 314 destTime := dest.Addr().Interface().(*time.Time) 315 *destTime = srcTime 316 return nil 317 default: 318 return scanError(dest.Type(), src) 319 } 320 } 321 322 func scanScanner(dest reflect.Value, src interface{}) error { 323 return dest.Interface().(sql.Scanner).Scan(src) 324 } 325 326 func scanMsgpack(dest reflect.Value, src interface{}) error { 327 if src == nil { 328 return scanNull(dest) 329 } 330 331 b, err := toBytes(src) 332 if err != nil { 333 return err 334 } 335 336 dec := msgpack.GetDecoder() 337 defer msgpack.PutDecoder(dec) 338 339 dec.Reset(bytes.NewReader(b)) 340 return dec.DecodeValue(dest) 341 } 342 343 func scanJSON(dest reflect.Value, src interface{}) error { 344 if src == nil { 345 return scanNull(dest) 346 } 347 348 b, err := toBytes(src) 349 if err != nil { 350 return err 351 } 352 353 return bunjson.Unmarshal(b, dest.Addr().Interface()) 354 } 355 356 func scanJSONUseNumber(dest reflect.Value, src interface{}) error { 357 if src == nil { 358 return scanNull(dest) 359 } 360 361 b, err := toBytes(src) 362 if err != nil { 363 return err 364 } 365 366 dec := bunjson.NewDecoder(bytes.NewReader(b)) 367 dec.UseNumber() 368 return dec.Decode(dest.Addr().Interface()) 369 } 370 371 func scanIP(dest reflect.Value, src interface{}) error { 372 if src == nil { 373 return scanNull(dest) 374 } 375 376 b, err := toBytes(src) 377 if err != nil { 378 return err 379 } 380 381 ip := net.ParseIP(internal.String(b)) 382 if ip == nil { 383 return fmt.Errorf("bun: invalid ip: %q", b) 384 } 385 386 ptr := dest.Addr().Interface().(*net.IP) 387 *ptr = ip 388 389 return nil 390 } 391 392 func scanIPNet(dest reflect.Value, src interface{}) error { 393 if src == nil { 394 return scanNull(dest) 395 } 396 397 b, err := toBytes(src) 398 if err != nil { 399 return err 400 } 401 402 _, ipnet, err := net.ParseCIDR(internal.String(b)) 403 if err != nil { 404 return err 405 } 406 407 ptr := dest.Addr().Interface().(*net.IPNet) 408 *ptr = *ipnet 409 410 return nil 411 } 412 413 func addrScanner(fn ScannerFunc) ScannerFunc { 414 return func(dest reflect.Value, src interface{}) error { 415 if !dest.CanAddr() { 416 return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface()) 417 } 418 return fn(dest.Addr(), src) 419 } 420 } 421 422 func toBytes(src interface{}) ([]byte, error) { 423 switch src := src.(type) { 424 case string: 425 return internal.Bytes(src), nil 426 case []byte: 427 return src, nil 428 default: 429 return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) 430 } 431 } 432 433 func PtrScanner(fn ScannerFunc) ScannerFunc { 434 return func(dest reflect.Value, src interface{}) error { 435 if src == nil { 436 if !dest.CanAddr() { 437 if dest.IsNil() { 438 return nil 439 } 440 return fn(dest.Elem(), src) 441 } 442 443 if !dest.IsNil() { 444 dest.Set(reflect.New(dest.Type().Elem())) 445 } 446 return nil 447 } 448 449 if dest.IsNil() { 450 dest.Set(reflect.New(dest.Type().Elem())) 451 } 452 453 if dest.Kind() == reflect.Map { 454 return fn(dest, src) 455 } 456 457 return fn(dest.Elem(), src) 458 } 459 } 460 461 func scanNull(dest reflect.Value) error { 462 if nilable(dest.Kind()) && dest.IsNil() { 463 return nil 464 } 465 dest.Set(reflect.New(dest.Type()).Elem()) 466 return nil 467 } 468 469 func scanJSONIntoInterface(dest reflect.Value, src interface{}) error { 470 if dest.IsNil() { 471 if src == nil { 472 return nil 473 } 474 475 b, err := toBytes(src) 476 if err != nil { 477 return err 478 } 479 480 return bunjson.Unmarshal(b, dest.Addr().Interface()) 481 } 482 483 dest = dest.Elem() 484 if fn := Scanner(dest.Type()); fn != nil { 485 return fn(dest, src) 486 } 487 return scanError(dest.Type(), src) 488 } 489 490 func scanInterface(dest reflect.Value, src interface{}) error { 491 if dest.IsNil() { 492 if src == nil { 493 return nil 494 } 495 dest.Set(reflect.ValueOf(src)) 496 return nil 497 } 498 499 dest = dest.Elem() 500 if fn := Scanner(dest.Type()); fn != nil { 501 return fn(dest, src) 502 } 503 return scanError(dest.Type(), src) 504 } 505 506 func nilable(kind reflect.Kind) bool { 507 switch kind { 508 case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: 509 return true 510 } 511 return false 512 } 513 514 func scanError(dest reflect.Type, src interface{}) error { 515 return fmt.Errorf("bun: can't scan %#v (%T) into %s", src, src, dest.String()) 516 }