gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

model_table_struct.go (8034B)


      1 package bun
      2 
      3 import (
      4 	"context"
      5 	"database/sql"
      6 	"fmt"
      7 	"reflect"
      8 	"strings"
      9 	"time"
     10 
     11 	"github.com/uptrace/bun/schema"
     12 )
     13 
     14 type structTableModel struct {
     15 	db    *DB
     16 	table *schema.Table
     17 
     18 	rel   *schema.Relation
     19 	joins []relationJoin
     20 
     21 	dest  interface{}
     22 	root  reflect.Value
     23 	index []int
     24 
     25 	strct         reflect.Value
     26 	structInited  bool
     27 	structInitErr error
     28 
     29 	columns   []string
     30 	scanIndex int
     31 }
     32 
     33 var _ TableModel = (*structTableModel)(nil)
     34 
     35 func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel {
     36 	return &structTableModel{
     37 		db:    db,
     38 		table: table,
     39 		dest:  dest,
     40 	}
     41 }
     42 
     43 func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel {
     44 	return &structTableModel{
     45 		db:    db,
     46 		table: db.Table(v.Type()),
     47 		dest:  dest,
     48 		root:  v,
     49 		strct: v,
     50 	}
     51 }
     52 
     53 func (m *structTableModel) Value() interface{} {
     54 	return m.dest
     55 }
     56 
     57 func (m *structTableModel) Table() *schema.Table {
     58 	return m.table
     59 }
     60 
     61 func (m *structTableModel) Relation() *schema.Relation {
     62 	return m.rel
     63 }
     64 
     65 func (m *structTableModel) initStruct() error {
     66 	if m.structInited {
     67 		return m.structInitErr
     68 	}
     69 	m.structInited = true
     70 
     71 	switch m.strct.Kind() {
     72 	case reflect.Invalid:
     73 		m.structInitErr = errNilModel
     74 		return m.structInitErr
     75 	case reflect.Interface:
     76 		m.strct = m.strct.Elem()
     77 	}
     78 
     79 	if m.strct.Kind() == reflect.Ptr {
     80 		if m.strct.IsNil() {
     81 			m.strct.Set(reflect.New(m.strct.Type().Elem()))
     82 			m.strct = m.strct.Elem()
     83 		} else {
     84 			m.strct = m.strct.Elem()
     85 		}
     86 	}
     87 
     88 	m.mountJoins()
     89 
     90 	return nil
     91 }
     92 
     93 func (m *structTableModel) mountJoins() {
     94 	for i := range m.joins {
     95 		j := &m.joins[i]
     96 		switch j.Relation.Type {
     97 		case schema.HasOneRelation, schema.BelongsToRelation:
     98 			j.JoinModel.mount(m.strct)
     99 		}
    100 	}
    101 }
    102 
    103 var _ schema.BeforeAppendModelHook = (*structTableModel)(nil)
    104 
    105 func (m *structTableModel) BeforeAppendModel(ctx context.Context, query Query) error {
    106 	if !m.table.HasBeforeAppendModelHook() || !m.strct.IsValid() {
    107 		return nil
    108 	}
    109 	return m.strct.Addr().Interface().(schema.BeforeAppendModelHook).BeforeAppendModel(ctx, query)
    110 }
    111 
    112 var _ schema.BeforeScanRowHook = (*structTableModel)(nil)
    113 
    114 func (m *structTableModel) BeforeScanRow(ctx context.Context) error {
    115 	if m.table.HasBeforeScanRowHook() {
    116 		return m.strct.Addr().Interface().(schema.BeforeScanRowHook).BeforeScanRow(ctx)
    117 	}
    118 	if m.table.HasBeforeScanHook() {
    119 		return m.strct.Addr().Interface().(schema.BeforeScanHook).BeforeScan(ctx)
    120 	}
    121 	return nil
    122 }
    123 
    124 var _ schema.AfterScanRowHook = (*structTableModel)(nil)
    125 
    126 func (m *structTableModel) AfterScanRow(ctx context.Context) error {
    127 	if !m.structInited {
    128 		return nil
    129 	}
    130 
    131 	if m.table.HasAfterScanRowHook() {
    132 		firstErr := m.strct.Addr().Interface().(schema.AfterScanRowHook).AfterScanRow(ctx)
    133 
    134 		for _, j := range m.joins {
    135 			switch j.Relation.Type {
    136 			case schema.HasOneRelation, schema.BelongsToRelation:
    137 				if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil {
    138 					firstErr = err
    139 				}
    140 			}
    141 		}
    142 
    143 		return firstErr
    144 	}
    145 
    146 	if m.table.HasAfterScanHook() {
    147 		firstErr := m.strct.Addr().Interface().(schema.AfterScanHook).AfterScan(ctx)
    148 
    149 		for _, j := range m.joins {
    150 			switch j.Relation.Type {
    151 			case schema.HasOneRelation, schema.BelongsToRelation:
    152 				if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil {
    153 					firstErr = err
    154 				}
    155 			}
    156 		}
    157 
    158 		return firstErr
    159 	}
    160 
    161 	return nil
    162 }
    163 
    164 func (m *structTableModel) getJoin(name string) *relationJoin {
    165 	for i := range m.joins {
    166 		j := &m.joins[i]
    167 		if j.Relation.Field.Name == name || j.Relation.Field.GoName == name {
    168 			return j
    169 		}
    170 	}
    171 	return nil
    172 }
    173 
    174 func (m *structTableModel) getJoins() []relationJoin {
    175 	return m.joins
    176 }
    177 
    178 func (m *structTableModel) addJoin(j relationJoin) *relationJoin {
    179 	m.joins = append(m.joins, j)
    180 	return &m.joins[len(m.joins)-1]
    181 }
    182 
    183 func (m *structTableModel) join(name string) *relationJoin {
    184 	return m._join(m.strct, name)
    185 }
    186 
    187 func (m *structTableModel) _join(bind reflect.Value, name string) *relationJoin {
    188 	path := strings.Split(name, ".")
    189 	index := make([]int, 0, len(path))
    190 
    191 	currJoin := relationJoin{
    192 		BaseModel: m,
    193 		JoinModel: m,
    194 	}
    195 	var lastJoin *relationJoin
    196 
    197 	for _, name := range path {
    198 		relation, ok := currJoin.JoinModel.Table().Relations[name]
    199 		if !ok {
    200 			return nil
    201 		}
    202 
    203 		currJoin.Relation = relation
    204 		index = append(index, relation.Field.Index...)
    205 
    206 		if j := currJoin.JoinModel.getJoin(name); j != nil {
    207 			currJoin.BaseModel = j.BaseModel
    208 			currJoin.JoinModel = j.JoinModel
    209 
    210 			lastJoin = j
    211 		} else {
    212 			model, err := newTableModelIndex(m.db, m.table, bind, index, relation)
    213 			if err != nil {
    214 				return nil
    215 			}
    216 
    217 			currJoin.Parent = lastJoin
    218 			currJoin.BaseModel = currJoin.JoinModel
    219 			currJoin.JoinModel = model
    220 
    221 			lastJoin = currJoin.BaseModel.addJoin(currJoin)
    222 		}
    223 	}
    224 
    225 	return lastJoin
    226 }
    227 
    228 func (m *structTableModel) rootValue() reflect.Value {
    229 	return m.root
    230 }
    231 
    232 func (m *structTableModel) parentIndex() []int {
    233 	return m.index[:len(m.index)-len(m.rel.Field.Index)]
    234 }
    235 
    236 func (m *structTableModel) mount(host reflect.Value) {
    237 	m.strct = host.FieldByIndex(m.rel.Field.Index)
    238 	m.structInited = false
    239 }
    240 
    241 func (m *structTableModel) updateSoftDeleteField(tm time.Time) error {
    242 	if !m.strct.IsValid() {
    243 		return nil
    244 	}
    245 	fv := m.table.SoftDeleteField.Value(m.strct)
    246 	return m.table.UpdateSoftDeleteField(fv, tm)
    247 }
    248 
    249 func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
    250 	if !rows.Next() {
    251 		return 0, rows.Err()
    252 	}
    253 
    254 	var n int
    255 
    256 	if err := m.ScanRow(ctx, rows); err != nil {
    257 		return 0, err
    258 	}
    259 	n++
    260 
    261 	// And discard the rest. This is especially important for SQLite3, which can return
    262 	// a row like it was inserted sucessfully and then return an actual error for the next row.
    263 	// See issues/100.
    264 	for rows.Next() {
    265 		n++
    266 	}
    267 	if err := rows.Err(); err != nil {
    268 		return 0, err
    269 	}
    270 
    271 	return n, nil
    272 }
    273 
    274 func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
    275 	columns, err := rows.Columns()
    276 	if err != nil {
    277 		return err
    278 	}
    279 
    280 	m.columns = columns
    281 	dest := makeDest(m, len(columns))
    282 
    283 	return m.scanRow(ctx, rows, dest)
    284 }
    285 
    286 func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error {
    287 	if err := m.BeforeScanRow(ctx); err != nil {
    288 		return err
    289 	}
    290 
    291 	m.scanIndex = 0
    292 	if err := rows.Scan(dest...); err != nil {
    293 		return err
    294 	}
    295 
    296 	if err := m.AfterScanRow(ctx); err != nil {
    297 		return err
    298 	}
    299 
    300 	return nil
    301 }
    302 
    303 func (m *structTableModel) Scan(src interface{}) error {
    304 	column := m.columns[m.scanIndex]
    305 	m.scanIndex++
    306 
    307 	return m.ScanColumn(unquote(column), src)
    308 }
    309 
    310 func (m *structTableModel) ScanColumn(column string, src interface{}) error {
    311 	if ok, err := m.scanColumn(column, src); ok {
    312 		return err
    313 	}
    314 	if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) {
    315 		return nil
    316 	}
    317 	return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column)
    318 }
    319 
    320 func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) {
    321 	if src != nil {
    322 		if err := m.initStruct(); err != nil {
    323 			return true, err
    324 		}
    325 	}
    326 
    327 	if field, ok := m.table.FieldMap[column]; ok {
    328 		if src == nil && m.isNil() {
    329 			return true, nil
    330 		}
    331 		return true, field.ScanValue(m.strct, src)
    332 	}
    333 
    334 	if joinName, column := splitColumn(column); joinName != "" {
    335 		if join := m.getJoin(joinName); join != nil {
    336 			return true, join.JoinModel.ScanColumn(column, src)
    337 		}
    338 
    339 		if m.table.ModelName == joinName {
    340 			return true, m.ScanColumn(column, src)
    341 		}
    342 	}
    343 
    344 	return false, nil
    345 }
    346 
    347 func (m *structTableModel) isNil() bool {
    348 	return m.strct.Kind() == reflect.Ptr && m.strct.IsNil()
    349 }
    350 
    351 func (m *structTableModel) AppendNamedArg(
    352 	fmter schema.Formatter, b []byte, name string,
    353 ) ([]byte, bool) {
    354 	return m.table.AppendNamedArg(fmter, b, name, m.strct)
    355 }
    356 
    357 // sqlite3 sometimes does not unquote columns.
    358 func unquote(s string) string {
    359 	if s == "" {
    360 		return s
    361 	}
    362 	if s[0] == '"' && s[len(s)-1] == '"' {
    363 		return s[1 : len(s)-1]
    364 	}
    365 	return s
    366 }
    367 
    368 func splitColumn(s string) (string, string) {
    369 	if i := strings.Index(s, "__"); i >= 0 {
    370 		return s[:i], s[i+2:]
    371 	}
    372 	return "", s
    373 }