db.go (16474B)
1 package bun 2 3 import ( 4 "context" 5 "crypto/rand" 6 "database/sql" 7 "encoding/hex" 8 "fmt" 9 "reflect" 10 "strings" 11 "sync/atomic" 12 13 "github.com/uptrace/bun/dialect/feature" 14 "github.com/uptrace/bun/internal" 15 "github.com/uptrace/bun/schema" 16 ) 17 18 const ( 19 discardUnknownColumns internal.Flag = 1 << iota 20 ) 21 22 type DBStats struct { 23 Queries uint32 24 Errors uint32 25 } 26 27 type DBOption func(db *DB) 28 29 func WithDiscardUnknownColumns() DBOption { 30 return func(db *DB) { 31 db.flags = db.flags.Set(discardUnknownColumns) 32 } 33 } 34 35 type DB struct { 36 *sql.DB 37 38 dialect schema.Dialect 39 features feature.Feature 40 41 queryHooks []QueryHook 42 43 fmter schema.Formatter 44 flags internal.Flag 45 46 stats DBStats 47 } 48 49 func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { 50 dialect.Init(sqldb) 51 52 db := &DB{ 53 DB: sqldb, 54 dialect: dialect, 55 features: dialect.Features(), 56 fmter: schema.NewFormatter(dialect), 57 } 58 59 for _, opt := range opts { 60 opt(db) 61 } 62 63 return db 64 } 65 66 func (db *DB) String() string { 67 var b strings.Builder 68 b.WriteString("DB<dialect=") 69 b.WriteString(db.dialect.Name().String()) 70 b.WriteString(">") 71 return b.String() 72 } 73 74 func (db *DB) DBStats() DBStats { 75 return DBStats{ 76 Queries: atomic.LoadUint32(&db.stats.Queries), 77 Errors: atomic.LoadUint32(&db.stats.Errors), 78 } 79 } 80 81 func (db *DB) NewValues(model interface{}) *ValuesQuery { 82 return NewValuesQuery(db, model) 83 } 84 85 func (db *DB) NewMerge() *MergeQuery { 86 return NewMergeQuery(db) 87 } 88 89 func (db *DB) NewSelect() *SelectQuery { 90 return NewSelectQuery(db) 91 } 92 93 func (db *DB) NewInsert() *InsertQuery { 94 return NewInsertQuery(db) 95 } 96 97 func (db *DB) NewUpdate() *UpdateQuery { 98 return NewUpdateQuery(db) 99 } 100 101 func (db *DB) NewDelete() *DeleteQuery { 102 return NewDeleteQuery(db) 103 } 104 105 func (db *DB) NewRaw(query string, args ...interface{}) *RawQuery { 106 return NewRawQuery(db, query, args...) 107 } 108 109 func (db *DB) NewCreateTable() *CreateTableQuery { 110 return NewCreateTableQuery(db) 111 } 112 113 func (db *DB) NewDropTable() *DropTableQuery { 114 return NewDropTableQuery(db) 115 } 116 117 func (db *DB) NewCreateIndex() *CreateIndexQuery { 118 return NewCreateIndexQuery(db) 119 } 120 121 func (db *DB) NewDropIndex() *DropIndexQuery { 122 return NewDropIndexQuery(db) 123 } 124 125 func (db *DB) NewTruncateTable() *TruncateTableQuery { 126 return NewTruncateTableQuery(db) 127 } 128 129 func (db *DB) NewAddColumn() *AddColumnQuery { 130 return NewAddColumnQuery(db) 131 } 132 133 func (db *DB) NewDropColumn() *DropColumnQuery { 134 return NewDropColumnQuery(db) 135 } 136 137 func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error { 138 for _, model := range models { 139 if _, err := db.NewDropTable().Model(model).IfExists().Cascade().Exec(ctx); err != nil { 140 return err 141 } 142 if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil { 143 return err 144 } 145 } 146 return nil 147 } 148 149 func (db *DB) Dialect() schema.Dialect { 150 return db.dialect 151 } 152 153 func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { 154 defer rows.Close() 155 156 model, err := newModel(db, dest) 157 if err != nil { 158 return err 159 } 160 161 _, err = model.ScanRows(ctx, rows) 162 if err != nil { 163 return err 164 } 165 166 return rows.Err() 167 } 168 169 func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { 170 model, err := newModel(db, dest) 171 if err != nil { 172 return err 173 } 174 175 rs, ok := model.(rowScanner) 176 if !ok { 177 return fmt.Errorf("bun: %T does not support ScanRow", model) 178 } 179 180 return rs.ScanRow(ctx, rows) 181 } 182 183 type queryHookIniter interface { 184 Init(db *DB) 185 } 186 187 func (db *DB) AddQueryHook(hook QueryHook) { 188 if initer, ok := hook.(queryHookIniter); ok { 189 initer.Init(db) 190 } 191 db.queryHooks = append(db.queryHooks, hook) 192 } 193 194 func (db *DB) Table(typ reflect.Type) *schema.Table { 195 return db.dialect.Tables().Get(typ) 196 } 197 198 // RegisterModel registers models by name so they can be referenced in table relations 199 // and fixtures. 200 func (db *DB) RegisterModel(models ...interface{}) { 201 db.dialect.Tables().Register(models...) 202 } 203 204 func (db *DB) clone() *DB { 205 clone := *db 206 207 l := len(clone.queryHooks) 208 clone.queryHooks = clone.queryHooks[:l:l] 209 210 return &clone 211 } 212 213 func (db *DB) WithNamedArg(name string, value interface{}) *DB { 214 clone := db.clone() 215 clone.fmter = clone.fmter.WithNamedArg(name, value) 216 return clone 217 } 218 219 func (db *DB) Formatter() schema.Formatter { 220 return db.fmter 221 } 222 223 // UpdateFQN returns a fully qualified column name. For MySQL, it returns the column name with 224 // the table alias. For other RDBMS, it returns just the column name. 225 func (db *DB) UpdateFQN(alias, column string) Ident { 226 if db.HasFeature(feature.UpdateMultiTable) { 227 return Ident(alias + "." + column) 228 } 229 return Ident(column) 230 } 231 232 // HasFeature uses feature package to report whether the underlying DBMS supports this feature. 233 func (db *DB) HasFeature(feat feature.Feature) bool { 234 return db.fmter.HasFeature(feat) 235 } 236 237 //------------------------------------------------------------------------------ 238 239 func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { 240 return db.ExecContext(context.Background(), query, args...) 241 } 242 243 func (db *DB) ExecContext( 244 ctx context.Context, query string, args ...interface{}, 245 ) (sql.Result, error) { 246 formattedQuery := db.format(query, args) 247 ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 248 res, err := db.DB.ExecContext(ctx, formattedQuery) 249 db.afterQuery(ctx, event, res, err) 250 return res, err 251 } 252 253 func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { 254 return db.QueryContext(context.Background(), query, args...) 255 } 256 257 func (db *DB) QueryContext( 258 ctx context.Context, query string, args ...interface{}, 259 ) (*sql.Rows, error) { 260 formattedQuery := db.format(query, args) 261 ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 262 rows, err := db.DB.QueryContext(ctx, formattedQuery) 263 db.afterQuery(ctx, event, nil, err) 264 return rows, err 265 } 266 267 func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { 268 return db.QueryRowContext(context.Background(), query, args...) 269 } 270 271 func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 272 formattedQuery := db.format(query, args) 273 ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 274 row := db.DB.QueryRowContext(ctx, formattedQuery) 275 db.afterQuery(ctx, event, nil, row.Err()) 276 return row 277 } 278 279 func (db *DB) format(query string, args []interface{}) string { 280 return db.fmter.FormatQuery(query, args...) 281 } 282 283 //------------------------------------------------------------------------------ 284 285 type Conn struct { 286 db *DB 287 *sql.Conn 288 } 289 290 func (db *DB) Conn(ctx context.Context) (Conn, error) { 291 conn, err := db.DB.Conn(ctx) 292 if err != nil { 293 return Conn{}, err 294 } 295 return Conn{ 296 db: db, 297 Conn: conn, 298 }, nil 299 } 300 301 func (c Conn) ExecContext( 302 ctx context.Context, query string, args ...interface{}, 303 ) (sql.Result, error) { 304 formattedQuery := c.db.format(query, args) 305 ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 306 res, err := c.Conn.ExecContext(ctx, formattedQuery) 307 c.db.afterQuery(ctx, event, res, err) 308 return res, err 309 } 310 311 func (c Conn) QueryContext( 312 ctx context.Context, query string, args ...interface{}, 313 ) (*sql.Rows, error) { 314 formattedQuery := c.db.format(query, args) 315 ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 316 rows, err := c.Conn.QueryContext(ctx, formattedQuery) 317 c.db.afterQuery(ctx, event, nil, err) 318 return rows, err 319 } 320 321 func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 322 formattedQuery := c.db.format(query, args) 323 ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 324 row := c.Conn.QueryRowContext(ctx, formattedQuery) 325 c.db.afterQuery(ctx, event, nil, row.Err()) 326 return row 327 } 328 329 func (c Conn) Dialect() schema.Dialect { 330 return c.db.Dialect() 331 } 332 333 func (c Conn) NewValues(model interface{}) *ValuesQuery { 334 return NewValuesQuery(c.db, model).Conn(c) 335 } 336 337 func (c Conn) NewMerge() *MergeQuery { 338 return NewMergeQuery(c.db).Conn(c) 339 } 340 341 func (c Conn) NewSelect() *SelectQuery { 342 return NewSelectQuery(c.db).Conn(c) 343 } 344 345 func (c Conn) NewInsert() *InsertQuery { 346 return NewInsertQuery(c.db).Conn(c) 347 } 348 349 func (c Conn) NewUpdate() *UpdateQuery { 350 return NewUpdateQuery(c.db).Conn(c) 351 } 352 353 func (c Conn) NewDelete() *DeleteQuery { 354 return NewDeleteQuery(c.db).Conn(c) 355 } 356 357 func (c Conn) NewRaw(query string, args ...interface{}) *RawQuery { 358 return NewRawQuery(c.db, query, args...).Conn(c) 359 } 360 361 func (c Conn) NewCreateTable() *CreateTableQuery { 362 return NewCreateTableQuery(c.db).Conn(c) 363 } 364 365 func (c Conn) NewDropTable() *DropTableQuery { 366 return NewDropTableQuery(c.db).Conn(c) 367 } 368 369 func (c Conn) NewCreateIndex() *CreateIndexQuery { 370 return NewCreateIndexQuery(c.db).Conn(c) 371 } 372 373 func (c Conn) NewDropIndex() *DropIndexQuery { 374 return NewDropIndexQuery(c.db).Conn(c) 375 } 376 377 func (c Conn) NewTruncateTable() *TruncateTableQuery { 378 return NewTruncateTableQuery(c.db).Conn(c) 379 } 380 381 func (c Conn) NewAddColumn() *AddColumnQuery { 382 return NewAddColumnQuery(c.db).Conn(c) 383 } 384 385 func (c Conn) NewDropColumn() *DropColumnQuery { 386 return NewDropColumnQuery(c.db).Conn(c) 387 } 388 389 // RunInTx runs the function in a transaction. If the function returns an error, 390 // the transaction is rolled back. Otherwise, the transaction is committed. 391 func (c Conn) RunInTx( 392 ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, 393 ) error { 394 tx, err := c.BeginTx(ctx, opts) 395 if err != nil { 396 return err 397 } 398 399 var done bool 400 401 defer func() { 402 if !done { 403 _ = tx.Rollback() 404 } 405 }() 406 407 if err := fn(ctx, tx); err != nil { 408 return err 409 } 410 411 done = true 412 return tx.Commit() 413 } 414 415 func (c Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { 416 ctx, event := c.db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil) 417 tx, err := c.Conn.BeginTx(ctx, opts) 418 c.db.afterQuery(ctx, event, nil, err) 419 if err != nil { 420 return Tx{}, err 421 } 422 return Tx{ 423 ctx: ctx, 424 db: c.db, 425 Tx: tx, 426 }, nil 427 } 428 429 //------------------------------------------------------------------------------ 430 431 type Stmt struct { 432 *sql.Stmt 433 } 434 435 func (db *DB) Prepare(query string) (Stmt, error) { 436 return db.PrepareContext(context.Background(), query) 437 } 438 439 func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { 440 stmt, err := db.DB.PrepareContext(ctx, query) 441 if err != nil { 442 return Stmt{}, err 443 } 444 return Stmt{Stmt: stmt}, nil 445 } 446 447 //------------------------------------------------------------------------------ 448 449 type Tx struct { 450 ctx context.Context 451 db *DB 452 // name is the name of a savepoint 453 name string 454 *sql.Tx 455 } 456 457 // RunInTx runs the function in a transaction. If the function returns an error, 458 // the transaction is rolled back. Otherwise, the transaction is committed. 459 func (db *DB) RunInTx( 460 ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, 461 ) error { 462 tx, err := db.BeginTx(ctx, opts) 463 if err != nil { 464 return err 465 } 466 467 var done bool 468 469 defer func() { 470 if !done { 471 _ = tx.Rollback() 472 } 473 }() 474 475 if err := fn(ctx, tx); err != nil { 476 return err 477 } 478 479 done = true 480 return tx.Commit() 481 } 482 483 func (db *DB) Begin() (Tx, error) { 484 return db.BeginTx(context.Background(), nil) 485 } 486 487 func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { 488 ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil) 489 tx, err := db.DB.BeginTx(ctx, opts) 490 db.afterQuery(ctx, event, nil, err) 491 if err != nil { 492 return Tx{}, err 493 } 494 return Tx{ 495 ctx: ctx, 496 db: db, 497 Tx: tx, 498 }, nil 499 } 500 501 func (tx Tx) Commit() error { 502 if tx.name == "" { 503 return tx.commitTX() 504 } 505 return tx.commitSP() 506 } 507 508 func (tx Tx) commitTX() error { 509 ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil) 510 err := tx.Tx.Commit() 511 tx.db.afterQuery(ctx, event, nil, err) 512 return err 513 } 514 515 func (tx Tx) commitSP() error { 516 if tx.Dialect().Features().Has(feature.MSSavepoint) { 517 return nil 518 } 519 query := "RELEASE SAVEPOINT " + tx.name 520 _, err := tx.ExecContext(tx.ctx, query) 521 return err 522 } 523 524 func (tx Tx) Rollback() error { 525 if tx.name == "" { 526 return tx.rollbackTX() 527 } 528 return tx.rollbackSP() 529 } 530 531 func (tx Tx) rollbackTX() error { 532 ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil) 533 err := tx.Tx.Rollback() 534 tx.db.afterQuery(ctx, event, nil, err) 535 return err 536 } 537 538 func (tx Tx) rollbackSP() error { 539 query := "ROLLBACK TO SAVEPOINT " + tx.name 540 if tx.Dialect().Features().Has(feature.MSSavepoint) { 541 query = "ROLLBACK TRANSACTION " + tx.name 542 } 543 _, err := tx.ExecContext(tx.ctx, query) 544 return err 545 } 546 547 func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { 548 return tx.ExecContext(context.TODO(), query, args...) 549 } 550 551 func (tx Tx) ExecContext( 552 ctx context.Context, query string, args ...interface{}, 553 ) (sql.Result, error) { 554 formattedQuery := tx.db.format(query, args) 555 ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 556 res, err := tx.Tx.ExecContext(ctx, formattedQuery) 557 tx.db.afterQuery(ctx, event, res, err) 558 return res, err 559 } 560 561 func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { 562 return tx.QueryContext(context.TODO(), query, args...) 563 } 564 565 func (tx Tx) QueryContext( 566 ctx context.Context, query string, args ...interface{}, 567 ) (*sql.Rows, error) { 568 formattedQuery := tx.db.format(query, args) 569 ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 570 rows, err := tx.Tx.QueryContext(ctx, formattedQuery) 571 tx.db.afterQuery(ctx, event, nil, err) 572 return rows, err 573 } 574 575 func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row { 576 return tx.QueryRowContext(context.TODO(), query, args...) 577 } 578 579 func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 580 formattedQuery := tx.db.format(query, args) 581 ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) 582 row := tx.Tx.QueryRowContext(ctx, formattedQuery) 583 tx.db.afterQuery(ctx, event, nil, row.Err()) 584 return row 585 } 586 587 //------------------------------------------------------------------------------ 588 589 func (tx Tx) Begin() (Tx, error) { 590 return tx.BeginTx(tx.ctx, nil) 591 } 592 593 // BeginTx will save a point in the running transaction. 594 func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) { 595 // mssql savepoint names are limited to 32 characters 596 sp := make([]byte, 14) 597 _, err := rand.Read(sp) 598 if err != nil { 599 return Tx{}, err 600 } 601 602 qName := "SP_" + hex.EncodeToString(sp) 603 query := "SAVEPOINT " + qName 604 if tx.Dialect().Features().Has(feature.MSSavepoint) { 605 query = "SAVE TRANSACTION " + qName 606 } 607 _, err = tx.ExecContext(ctx, query) 608 if err != nil { 609 return Tx{}, err 610 } 611 return Tx{ 612 ctx: ctx, 613 db: tx.db, 614 Tx: tx.Tx, 615 name: qName, 616 }, nil 617 } 618 619 func (tx Tx) RunInTx( 620 ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, 621 ) error { 622 sp, err := tx.BeginTx(ctx, nil) 623 if err != nil { 624 return err 625 } 626 627 var done bool 628 629 defer func() { 630 if !done { 631 _ = sp.Rollback() 632 } 633 }() 634 635 if err := fn(ctx, sp); err != nil { 636 return err 637 } 638 639 done = true 640 return sp.Commit() 641 } 642 643 func (tx Tx) Dialect() schema.Dialect { 644 return tx.db.Dialect() 645 } 646 647 func (tx Tx) NewValues(model interface{}) *ValuesQuery { 648 return NewValuesQuery(tx.db, model).Conn(tx) 649 } 650 651 func (tx Tx) NewMerge() *MergeQuery { 652 return NewMergeQuery(tx.db).Conn(tx) 653 } 654 655 func (tx Tx) NewSelect() *SelectQuery { 656 return NewSelectQuery(tx.db).Conn(tx) 657 } 658 659 func (tx Tx) NewInsert() *InsertQuery { 660 return NewInsertQuery(tx.db).Conn(tx) 661 } 662 663 func (tx Tx) NewUpdate() *UpdateQuery { 664 return NewUpdateQuery(tx.db).Conn(tx) 665 } 666 667 func (tx Tx) NewDelete() *DeleteQuery { 668 return NewDeleteQuery(tx.db).Conn(tx) 669 } 670 671 func (tx Tx) NewRaw(query string, args ...interface{}) *RawQuery { 672 return NewRawQuery(tx.db, query, args...).Conn(tx) 673 } 674 675 func (tx Tx) NewCreateTable() *CreateTableQuery { 676 return NewCreateTableQuery(tx.db).Conn(tx) 677 } 678 679 func (tx Tx) NewDropTable() *DropTableQuery { 680 return NewDropTableQuery(tx.db).Conn(tx) 681 } 682 683 func (tx Tx) NewCreateIndex() *CreateIndexQuery { 684 return NewCreateIndexQuery(tx.db).Conn(tx) 685 } 686 687 func (tx Tx) NewDropIndex() *DropIndexQuery { 688 return NewDropIndexQuery(tx.db).Conn(tx) 689 } 690 691 func (tx Tx) NewTruncateTable() *TruncateTableQuery { 692 return NewTruncateTableQuery(tx.db).Conn(tx) 693 } 694 695 func (tx Tx) NewAddColumn() *AddColumnQuery { 696 return NewAddColumnQuery(tx.db).Conn(tx) 697 } 698 699 func (tx Tx) NewDropColumn() *DropColumnQuery { 700 return NewDropColumnQuery(tx.db).Conn(tx) 701 } 702 703 //------------------------------------------------------------------------------ 704 705 func (db *DB) makeQueryBytes() []byte { 706 // TODO: make this configurable? 707 return make([]byte, 0, 4096) 708 }