model_table_struct.go (8034B)
1 package bun 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "reflect" 8 "strings" 9 "time" 10 11 "github.com/uptrace/bun/schema" 12 ) 13 14 type structTableModel struct { 15 db *DB 16 table *schema.Table 17 18 rel *schema.Relation 19 joins []relationJoin 20 21 dest interface{} 22 root reflect.Value 23 index []int 24 25 strct reflect.Value 26 structInited bool 27 structInitErr error 28 29 columns []string 30 scanIndex int 31 } 32 33 var _ TableModel = (*structTableModel)(nil) 34 35 func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel { 36 return &structTableModel{ 37 db: db, 38 table: table, 39 dest: dest, 40 } 41 } 42 43 func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel { 44 return &structTableModel{ 45 db: db, 46 table: db.Table(v.Type()), 47 dest: dest, 48 root: v, 49 strct: v, 50 } 51 } 52 53 func (m *structTableModel) Value() interface{} { 54 return m.dest 55 } 56 57 func (m *structTableModel) Table() *schema.Table { 58 return m.table 59 } 60 61 func (m *structTableModel) Relation() *schema.Relation { 62 return m.rel 63 } 64 65 func (m *structTableModel) initStruct() error { 66 if m.structInited { 67 return m.structInitErr 68 } 69 m.structInited = true 70 71 switch m.strct.Kind() { 72 case reflect.Invalid: 73 m.structInitErr = errNilModel 74 return m.structInitErr 75 case reflect.Interface: 76 m.strct = m.strct.Elem() 77 } 78 79 if m.strct.Kind() == reflect.Ptr { 80 if m.strct.IsNil() { 81 m.strct.Set(reflect.New(m.strct.Type().Elem())) 82 m.strct = m.strct.Elem() 83 } else { 84 m.strct = m.strct.Elem() 85 } 86 } 87 88 m.mountJoins() 89 90 return nil 91 } 92 93 func (m *structTableModel) mountJoins() { 94 for i := range m.joins { 95 j := &m.joins[i] 96 switch j.Relation.Type { 97 case schema.HasOneRelation, schema.BelongsToRelation: 98 j.JoinModel.mount(m.strct) 99 } 100 } 101 } 102 103 var _ schema.BeforeAppendModelHook = (*structTableModel)(nil) 104 105 func (m *structTableModel) BeforeAppendModel(ctx context.Context, query Query) error { 106 if !m.table.HasBeforeAppendModelHook() || !m.strct.IsValid() { 107 return nil 108 } 109 return m.strct.Addr().Interface().(schema.BeforeAppendModelHook).BeforeAppendModel(ctx, query) 110 } 111 112 var _ schema.BeforeScanRowHook = (*structTableModel)(nil) 113 114 func (m *structTableModel) BeforeScanRow(ctx context.Context) error { 115 if m.table.HasBeforeScanRowHook() { 116 return m.strct.Addr().Interface().(schema.BeforeScanRowHook).BeforeScanRow(ctx) 117 } 118 if m.table.HasBeforeScanHook() { 119 return m.strct.Addr().Interface().(schema.BeforeScanHook).BeforeScan(ctx) 120 } 121 return nil 122 } 123 124 var _ schema.AfterScanRowHook = (*structTableModel)(nil) 125 126 func (m *structTableModel) AfterScanRow(ctx context.Context) error { 127 if !m.structInited { 128 return nil 129 } 130 131 if m.table.HasAfterScanRowHook() { 132 firstErr := m.strct.Addr().Interface().(schema.AfterScanRowHook).AfterScanRow(ctx) 133 134 for _, j := range m.joins { 135 switch j.Relation.Type { 136 case schema.HasOneRelation, schema.BelongsToRelation: 137 if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil { 138 firstErr = err 139 } 140 } 141 } 142 143 return firstErr 144 } 145 146 if m.table.HasAfterScanHook() { 147 firstErr := m.strct.Addr().Interface().(schema.AfterScanHook).AfterScan(ctx) 148 149 for _, j := range m.joins { 150 switch j.Relation.Type { 151 case schema.HasOneRelation, schema.BelongsToRelation: 152 if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil { 153 firstErr = err 154 } 155 } 156 } 157 158 return firstErr 159 } 160 161 return nil 162 } 163 164 func (m *structTableModel) getJoin(name string) *relationJoin { 165 for i := range m.joins { 166 j := &m.joins[i] 167 if j.Relation.Field.Name == name || j.Relation.Field.GoName == name { 168 return j 169 } 170 } 171 return nil 172 } 173 174 func (m *structTableModel) getJoins() []relationJoin { 175 return m.joins 176 } 177 178 func (m *structTableModel) addJoin(j relationJoin) *relationJoin { 179 m.joins = append(m.joins, j) 180 return &m.joins[len(m.joins)-1] 181 } 182 183 func (m *structTableModel) join(name string) *relationJoin { 184 return m._join(m.strct, name) 185 } 186 187 func (m *structTableModel) _join(bind reflect.Value, name string) *relationJoin { 188 path := strings.Split(name, ".") 189 index := make([]int, 0, len(path)) 190 191 currJoin := relationJoin{ 192 BaseModel: m, 193 JoinModel: m, 194 } 195 var lastJoin *relationJoin 196 197 for _, name := range path { 198 relation, ok := currJoin.JoinModel.Table().Relations[name] 199 if !ok { 200 return nil 201 } 202 203 currJoin.Relation = relation 204 index = append(index, relation.Field.Index...) 205 206 if j := currJoin.JoinModel.getJoin(name); j != nil { 207 currJoin.BaseModel = j.BaseModel 208 currJoin.JoinModel = j.JoinModel 209 210 lastJoin = j 211 } else { 212 model, err := newTableModelIndex(m.db, m.table, bind, index, relation) 213 if err != nil { 214 return nil 215 } 216 217 currJoin.Parent = lastJoin 218 currJoin.BaseModel = currJoin.JoinModel 219 currJoin.JoinModel = model 220 221 lastJoin = currJoin.BaseModel.addJoin(currJoin) 222 } 223 } 224 225 return lastJoin 226 } 227 228 func (m *structTableModel) rootValue() reflect.Value { 229 return m.root 230 } 231 232 func (m *structTableModel) parentIndex() []int { 233 return m.index[:len(m.index)-len(m.rel.Field.Index)] 234 } 235 236 func (m *structTableModel) mount(host reflect.Value) { 237 m.strct = host.FieldByIndex(m.rel.Field.Index) 238 m.structInited = false 239 } 240 241 func (m *structTableModel) updateSoftDeleteField(tm time.Time) error { 242 if !m.strct.IsValid() { 243 return nil 244 } 245 fv := m.table.SoftDeleteField.Value(m.strct) 246 return m.table.UpdateSoftDeleteField(fv, tm) 247 } 248 249 func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { 250 if !rows.Next() { 251 return 0, rows.Err() 252 } 253 254 var n int 255 256 if err := m.ScanRow(ctx, rows); err != nil { 257 return 0, err 258 } 259 n++ 260 261 // And discard the rest. This is especially important for SQLite3, which can return 262 // a row like it was inserted sucessfully and then return an actual error for the next row. 263 // See issues/100. 264 for rows.Next() { 265 n++ 266 } 267 if err := rows.Err(); err != nil { 268 return 0, err 269 } 270 271 return n, nil 272 } 273 274 func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { 275 columns, err := rows.Columns() 276 if err != nil { 277 return err 278 } 279 280 m.columns = columns 281 dest := makeDest(m, len(columns)) 282 283 return m.scanRow(ctx, rows, dest) 284 } 285 286 func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error { 287 if err := m.BeforeScanRow(ctx); err != nil { 288 return err 289 } 290 291 m.scanIndex = 0 292 if err := rows.Scan(dest...); err != nil { 293 return err 294 } 295 296 if err := m.AfterScanRow(ctx); err != nil { 297 return err 298 } 299 300 return nil 301 } 302 303 func (m *structTableModel) Scan(src interface{}) error { 304 column := m.columns[m.scanIndex] 305 m.scanIndex++ 306 307 return m.ScanColumn(unquote(column), src) 308 } 309 310 func (m *structTableModel) ScanColumn(column string, src interface{}) error { 311 if ok, err := m.scanColumn(column, src); ok { 312 return err 313 } 314 if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) { 315 return nil 316 } 317 return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column) 318 } 319 320 func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) { 321 if src != nil { 322 if err := m.initStruct(); err != nil { 323 return true, err 324 } 325 } 326 327 if field, ok := m.table.FieldMap[column]; ok { 328 if src == nil && m.isNil() { 329 return true, nil 330 } 331 return true, field.ScanValue(m.strct, src) 332 } 333 334 if joinName, column := splitColumn(column); joinName != "" { 335 if join := m.getJoin(joinName); join != nil { 336 return true, join.JoinModel.ScanColumn(column, src) 337 } 338 339 if m.table.ModelName == joinName { 340 return true, m.ScanColumn(column, src) 341 } 342 } 343 344 return false, nil 345 } 346 347 func (m *structTableModel) isNil() bool { 348 return m.strct.Kind() == reflect.Ptr && m.strct.IsNil() 349 } 350 351 func (m *structTableModel) AppendNamedArg( 352 fmter schema.Formatter, b []byte, name string, 353 ) ([]byte, bool) { 354 return m.table.AppendNamedArg(fmter, b, name, m.strct) 355 } 356 357 // sqlite3 sometimes does not unquote columns. 358 func unquote(s string) string { 359 if s == "" { 360 return s 361 } 362 if s[0] == '"' && s[len(s)-1] == '"' { 363 return s[1 : len(s)-1] 364 } 365 return s 366 } 367 368 func splitColumn(s string) (string, string) { 369 if i := strings.Index(s, "__"); i >= 0 { 370 return s[:i], s[i+2:] 371 } 372 return "", s 373 }