gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

scan.go (10140B)


      1 package schema
      2 
      3 import (
      4 	"bytes"
      5 	"database/sql"
      6 	"fmt"
      7 	"net"
      8 	"reflect"
      9 	"strconv"
     10 	"strings"
     11 	"sync"
     12 	"time"
     13 
     14 	"github.com/vmihailenco/msgpack/v5"
     15 
     16 	"github.com/uptrace/bun/dialect/sqltype"
     17 	"github.com/uptrace/bun/extra/bunjson"
     18 	"github.com/uptrace/bun/internal"
     19 )
     20 
// scannerType is the reflect.Type of the sql.Scanner interface. scanner uses
// it to detect types (or pointers to types) that know how to scan themselves.
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

// ScannerFunc stores the raw column value src into dest, the reflect.Value of
// the destination, and reports any conversion error.
type ScannerFunc func(dest reflect.Value, src interface{}) error

// scanners is indexed by reflect.Kind and holds the fallback ScannerFunc for
// each kind; it is populated in init. A nil entry means the kind has no
// fallback scanner.
var scanners []ScannerFunc
     26 
     27 func init() {
     28 	scanners = []ScannerFunc{
     29 		reflect.Bool:          scanBool,
     30 		reflect.Int:           scanInt64,
     31 		reflect.Int8:          scanInt64,
     32 		reflect.Int16:         scanInt64,
     33 		reflect.Int32:         scanInt64,
     34 		reflect.Int64:         scanInt64,
     35 		reflect.Uint:          scanUint64,
     36 		reflect.Uint8:         scanUint64,
     37 		reflect.Uint16:        scanUint64,
     38 		reflect.Uint32:        scanUint64,
     39 		reflect.Uint64:        scanUint64,
     40 		reflect.Uintptr:       scanUint64,
     41 		reflect.Float32:       scanFloat64,
     42 		reflect.Float64:       scanFloat64,
     43 		reflect.Complex64:     nil,
     44 		reflect.Complex128:    nil,
     45 		reflect.Array:         nil,
     46 		reflect.Interface:     scanInterface,
     47 		reflect.Map:           scanJSON,
     48 		reflect.Ptr:           nil,
     49 		reflect.Slice:         scanJSON,
     50 		reflect.String:        scanString,
     51 		reflect.Struct:        scanJSON,
     52 		reflect.UnsafePointer: nil,
     53 	}
     54 }
     55 
// scannerMap memoizes Scanner results: it maps a reflect.Type key to the
// ScannerFunc computed for it by scanner (possibly a nil ScannerFunc).
var scannerMap sync.Map
     57 
     58 func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
     59 	if field.Tag.HasOption("msgpack") {
     60 		return scanMsgpack
     61 	}
     62 	if field.Tag.HasOption("json_use_number") {
     63 		return scanJSONUseNumber
     64 	}
     65 	if field.StructField.Type.Kind() == reflect.Interface {
     66 		switch strings.ToUpper(field.UserSQLType) {
     67 		case sqltype.JSON, sqltype.JSONB:
     68 			return scanJSONIntoInterface
     69 		}
     70 	}
     71 	return Scanner(field.StructField.Type)
     72 }
     73 
     74 func Scanner(typ reflect.Type) ScannerFunc {
     75 	if v, ok := scannerMap.Load(typ); ok {
     76 		return v.(ScannerFunc)
     77 	}
     78 
     79 	fn := scanner(typ)
     80 
     81 	if v, ok := scannerMap.LoadOrStore(typ, fn); ok {
     82 		return v.(ScannerFunc)
     83 	}
     84 	return fn
     85 }
     86 
     87 func scanner(typ reflect.Type) ScannerFunc {
     88 	kind := typ.Kind()
     89 
     90 	if kind == reflect.Ptr {
     91 		if fn := Scanner(typ.Elem()); fn != nil {
     92 			return PtrScanner(fn)
     93 		}
     94 	}
     95 
     96 	switch typ {
     97 	case bytesType:
     98 		return scanBytes
     99 	case timeType:
    100 		return scanTime
    101 	case ipType:
    102 		return scanIP
    103 	case ipNetType:
    104 		return scanIPNet
    105 	case jsonRawMessageType:
    106 		return scanBytes
    107 	}
    108 
    109 	if typ.Implements(scannerType) {
    110 		return scanScanner
    111 	}
    112 
    113 	if kind != reflect.Ptr {
    114 		ptr := reflect.PtrTo(typ)
    115 		if ptr.Implements(scannerType) {
    116 			return addrScanner(scanScanner)
    117 		}
    118 	}
    119 
    120 	if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
    121 		return scanBytes
    122 	}
    123 
    124 	return scanners[kind]
    125 }
    126 
    127 func scanBool(dest reflect.Value, src interface{}) error {
    128 	switch src := src.(type) {
    129 	case nil:
    130 		dest.SetBool(false)
    131 		return nil
    132 	case bool:
    133 		dest.SetBool(src)
    134 		return nil
    135 	case int64:
    136 		dest.SetBool(src != 0)
    137 		return nil
    138 	case []byte:
    139 		f, err := strconv.ParseBool(internal.String(src))
    140 		if err != nil {
    141 			return err
    142 		}
    143 		dest.SetBool(f)
    144 		return nil
    145 	case string:
    146 		f, err := strconv.ParseBool(src)
    147 		if err != nil {
    148 			return err
    149 		}
    150 		dest.SetBool(f)
    151 		return nil
    152 	default:
    153 		return scanError(dest.Type(), src)
    154 	}
    155 }
    156 
    157 func scanInt64(dest reflect.Value, src interface{}) error {
    158 	switch src := src.(type) {
    159 	case nil:
    160 		dest.SetInt(0)
    161 		return nil
    162 	case int64:
    163 		dest.SetInt(src)
    164 		return nil
    165 	case uint64:
    166 		dest.SetInt(int64(src))
    167 		return nil
    168 	case []byte:
    169 		n, err := strconv.ParseInt(internal.String(src), 10, 64)
    170 		if err != nil {
    171 			return err
    172 		}
    173 		dest.SetInt(n)
    174 		return nil
    175 	case string:
    176 		n, err := strconv.ParseInt(src, 10, 64)
    177 		if err != nil {
    178 			return err
    179 		}
    180 		dest.SetInt(n)
    181 		return nil
    182 	default:
    183 		return scanError(dest.Type(), src)
    184 	}
    185 }
    186 
    187 func scanUint64(dest reflect.Value, src interface{}) error {
    188 	switch src := src.(type) {
    189 	case nil:
    190 		dest.SetUint(0)
    191 		return nil
    192 	case uint64:
    193 		dest.SetUint(src)
    194 		return nil
    195 	case int64:
    196 		dest.SetUint(uint64(src))
    197 		return nil
    198 	case []byte:
    199 		n, err := strconv.ParseUint(internal.String(src), 10, 64)
    200 		if err != nil {
    201 			return err
    202 		}
    203 		dest.SetUint(n)
    204 		return nil
    205 	case string:
    206 		n, err := strconv.ParseUint(src, 10, 64)
    207 		if err != nil {
    208 			return err
    209 		}
    210 		dest.SetUint(n)
    211 		return nil
    212 	default:
    213 		return scanError(dest.Type(), src)
    214 	}
    215 }
    216 
    217 func scanFloat64(dest reflect.Value, src interface{}) error {
    218 	switch src := src.(type) {
    219 	case nil:
    220 		dest.SetFloat(0)
    221 		return nil
    222 	case float64:
    223 		dest.SetFloat(src)
    224 		return nil
    225 	case []byte:
    226 		f, err := strconv.ParseFloat(internal.String(src), 64)
    227 		if err != nil {
    228 			return err
    229 		}
    230 		dest.SetFloat(f)
    231 		return nil
    232 	case string:
    233 		f, err := strconv.ParseFloat(src, 64)
    234 		if err != nil {
    235 			return err
    236 		}
    237 		dest.SetFloat(f)
    238 		return nil
    239 	default:
    240 		return scanError(dest.Type(), src)
    241 	}
    242 }
    243 
    244 func scanString(dest reflect.Value, src interface{}) error {
    245 	switch src := src.(type) {
    246 	case nil:
    247 		dest.SetString("")
    248 		return nil
    249 	case string:
    250 		dest.SetString(src)
    251 		return nil
    252 	case []byte:
    253 		dest.SetString(string(src))
    254 		return nil
    255 	case time.Time:
    256 		dest.SetString(src.Format(time.RFC3339Nano))
    257 		return nil
    258 	case int64:
    259 		dest.SetString(strconv.FormatInt(src, 10))
    260 		return nil
    261 	case uint64:
    262 		dest.SetString(strconv.FormatUint(src, 10))
    263 		return nil
    264 	case float64:
    265 		dest.SetString(strconv.FormatFloat(src, 'G', -1, 64))
    266 		return nil
    267 	default:
    268 		return scanError(dest.Type(), src)
    269 	}
    270 }
    271 
    272 func scanBytes(dest reflect.Value, src interface{}) error {
    273 	switch src := src.(type) {
    274 	case nil:
    275 		dest.SetBytes(nil)
    276 		return nil
    277 	case string:
    278 		dest.SetBytes([]byte(src))
    279 		return nil
    280 	case []byte:
    281 		clone := make([]byte, len(src))
    282 		copy(clone, src)
    283 
    284 		dest.SetBytes(clone)
    285 		return nil
    286 	default:
    287 		return scanError(dest.Type(), src)
    288 	}
    289 }
    290 
    291 func scanTime(dest reflect.Value, src interface{}) error {
    292 	switch src := src.(type) {
    293 	case nil:
    294 		destTime := dest.Addr().Interface().(*time.Time)
    295 		*destTime = time.Time{}
    296 		return nil
    297 	case time.Time:
    298 		destTime := dest.Addr().Interface().(*time.Time)
    299 		*destTime = src
    300 		return nil
    301 	case string:
    302 		srcTime, err := internal.ParseTime(src)
    303 		if err != nil {
    304 			return err
    305 		}
    306 		destTime := dest.Addr().Interface().(*time.Time)
    307 		*destTime = srcTime
    308 		return nil
    309 	case []byte:
    310 		srcTime, err := internal.ParseTime(internal.String(src))
    311 		if err != nil {
    312 			return err
    313 		}
    314 		destTime := dest.Addr().Interface().(*time.Time)
    315 		*destTime = srcTime
    316 		return nil
    317 	default:
    318 		return scanError(dest.Type(), src)
    319 	}
    320 }
    321 
    322 func scanScanner(dest reflect.Value, src interface{}) error {
    323 	return dest.Interface().(sql.Scanner).Scan(src)
    324 }
    325 
    326 func scanMsgpack(dest reflect.Value, src interface{}) error {
    327 	if src == nil {
    328 		return scanNull(dest)
    329 	}
    330 
    331 	b, err := toBytes(src)
    332 	if err != nil {
    333 		return err
    334 	}
    335 
    336 	dec := msgpack.GetDecoder()
    337 	defer msgpack.PutDecoder(dec)
    338 
    339 	dec.Reset(bytes.NewReader(b))
    340 	return dec.DecodeValue(dest)
    341 }
    342 
    343 func scanJSON(dest reflect.Value, src interface{}) error {
    344 	if src == nil {
    345 		return scanNull(dest)
    346 	}
    347 
    348 	b, err := toBytes(src)
    349 	if err != nil {
    350 		return err
    351 	}
    352 
    353 	return bunjson.Unmarshal(b, dest.Addr().Interface())
    354 }
    355 
    356 func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
    357 	if src == nil {
    358 		return scanNull(dest)
    359 	}
    360 
    361 	b, err := toBytes(src)
    362 	if err != nil {
    363 		return err
    364 	}
    365 
    366 	dec := bunjson.NewDecoder(bytes.NewReader(b))
    367 	dec.UseNumber()
    368 	return dec.Decode(dest.Addr().Interface())
    369 }
    370 
    371 func scanIP(dest reflect.Value, src interface{}) error {
    372 	if src == nil {
    373 		return scanNull(dest)
    374 	}
    375 
    376 	b, err := toBytes(src)
    377 	if err != nil {
    378 		return err
    379 	}
    380 
    381 	ip := net.ParseIP(internal.String(b))
    382 	if ip == nil {
    383 		return fmt.Errorf("bun: invalid ip: %q", b)
    384 	}
    385 
    386 	ptr := dest.Addr().Interface().(*net.IP)
    387 	*ptr = ip
    388 
    389 	return nil
    390 }
    391 
    392 func scanIPNet(dest reflect.Value, src interface{}) error {
    393 	if src == nil {
    394 		return scanNull(dest)
    395 	}
    396 
    397 	b, err := toBytes(src)
    398 	if err != nil {
    399 		return err
    400 	}
    401 
    402 	_, ipnet, err := net.ParseCIDR(internal.String(b))
    403 	if err != nil {
    404 		return err
    405 	}
    406 
    407 	ptr := dest.Addr().Interface().(*net.IPNet)
    408 	*ptr = *ipnet
    409 
    410 	return nil
    411 }
    412 
    413 func addrScanner(fn ScannerFunc) ScannerFunc {
    414 	return func(dest reflect.Value, src interface{}) error {
    415 		if !dest.CanAddr() {
    416 			return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
    417 		}
    418 		return fn(dest.Addr(), src)
    419 	}
    420 }
    421 
    422 func toBytes(src interface{}) ([]byte, error) {
    423 	switch src := src.(type) {
    424 	case string:
    425 		return internal.Bytes(src), nil
    426 	case []byte:
    427 		return src, nil
    428 	default:
    429 		return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
    430 	}
    431 }
    432 
// PtrScanner wraps fn, a scanner for some type T, so it can scan into a
// destination of pointer type *T (scanner installs it for pointer kinds).
//
// For a NULL (nil) src:
//   - if dest is not addressable (so it cannot be replaced), a nil pointer is
//     left alone, and for a non-nil pointer the NULL is passed to fn on the
//     pointee;
//   - otherwise a non-nil dest is replaced with a pointer to a freshly
//     allocated zero T (not with nil), and a nil dest stays nil.
//
// For any other src, a nil dest is first pointed at a newly allocated T, and
// then fn scans src into the pointee.
//
// NOTE(review): on NULL, a non-nil addressable pointer becomes a pointer to a
// zero value rather than nil, so a reused destination does not see NULL the
// same way a fresh (nil) one does. Confirm this is intended.
func PtrScanner(fn ScannerFunc) ScannerFunc {
	return func(dest reflect.Value, src interface{}) error {
		if src == nil {
			// Cannot replace dest itself: delegate the NULL to the pointee.
			if !dest.CanAddr() {
				if dest.IsNil() {
					return nil
				}
				return fn(dest.Elem(), src)
			}

			if !dest.IsNil() {
				dest.Set(reflect.New(dest.Type().Elem()))
			}
			return nil
		}

		// Allocate a target for non-NULL data.
		if dest.IsNil() {
			dest.Set(reflect.New(dest.Type().Elem()))
		}

		// NOTE(review): scanner only installs PtrScanner for pointer kinds, so
		// this branch appears reachable only through direct callers of the
		// exported PtrScanner; confirm before relying on it.
		if dest.Kind() == reflect.Map {
			return fn(dest, src)
		}

		return fn(dest.Elem(), src)
	}
}
    460 
    461 func scanNull(dest reflect.Value) error {
    462 	if nilable(dest.Kind()) && dest.IsNil() {
    463 		return nil
    464 	}
    465 	dest.Set(reflect.New(dest.Type()).Elem())
    466 	return nil
    467 }
    468 
    469 func scanJSONIntoInterface(dest reflect.Value, src interface{}) error {
    470 	if dest.IsNil() {
    471 		if src == nil {
    472 			return nil
    473 		}
    474 
    475 		b, err := toBytes(src)
    476 		if err != nil {
    477 			return err
    478 		}
    479 
    480 		return bunjson.Unmarshal(b, dest.Addr().Interface())
    481 	}
    482 
    483 	dest = dest.Elem()
    484 	if fn := Scanner(dest.Type()); fn != nil {
    485 		return fn(dest, src)
    486 	}
    487 	return scanError(dest.Type(), src)
    488 }
    489 
    490 func scanInterface(dest reflect.Value, src interface{}) error {
    491 	if dest.IsNil() {
    492 		if src == nil {
    493 			return nil
    494 		}
    495 		dest.Set(reflect.ValueOf(src))
    496 		return nil
    497 	}
    498 
    499 	dest = dest.Elem()
    500 	if fn := Scanner(dest.Type()); fn != nil {
    501 		return fn(dest, src)
    502 	}
    503 	return scanError(dest.Type(), src)
    504 }
    505 
    506 func nilable(kind reflect.Kind) bool {
    507 	switch kind {
    508 	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
    509 		return true
    510 	}
    511 	return false
    512 }
    513 
    514 func scanError(dest reflect.Type, src interface{}) error {
    515 	return fmt.Errorf("bun: can't scan %#v (%T) into %s", src, src, dest.String())
    516 }