commit 6a2d0d9508392425392738743fa4dd5b5d93bdb4 parent 71a4f8667c63c2ffbe58eee7b5e963d3e09a42de Author: kim (grufwub) <grufwub@gmail.com> Date: Wed, 8 Sep 2021 21:13:54 +0100 Merge remote-tracking branch 'upstream/main' into update/sqlite-library Signed-off-by: kim (grufwub) <grufwub@gmail.com> Diffstat:
35 files changed, 517 insertions(+), 395 deletions(-)
diff --git a/go.mod b/go.mod @@ -47,7 +47,7 @@ require ( github.com/superseriousbusiness/oauth2/v4 v4.3.0-SSB github.com/tdewolff/minify/v2 v2.9.21 github.com/tidwall/buntdb v1.2.4 // indirect - github.com/uptrace/bun v0.4.3 + github.com/uptrace/bun v1.0.4 github.com/uptrace/bun/dialect/pgdialect v0.4.3 github.com/uptrace/bun/dialect/sqlitedialect v0.4.3 github.com/urfave/cli/v2 v2.3.0 diff --git a/go.sum b/go.sum @@ -454,8 +454,9 @@ github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.6 h1:7kbGefxLoDBuYXOms4yD7223OpNMMPNPZxXk5TvFcyQ= github.com/ugorji/go/codec v1.2.6/go.mod h1:V6TCNZ4PHqoHGFZuSG1W8nrCzzdgA2DozYxWFFpvxTw= -github.com/uptrace/bun v0.4.3 h1:x6bjDqwjxwM/9Q1eauhkznuvTrz/rLiCK2p4tT63sAE= github.com/uptrace/bun v0.4.3/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw= +github.com/uptrace/bun v1.0.4 h1:XKkddp+F5rbjyZCfEXPHc9ZEG3RE8VktO4HCcg5nzCQ= +github.com/uptrace/bun v1.0.4/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw= github.com/uptrace/bun/dialect/pgdialect v0.4.3 h1:lM2IUKpU99110chKkupw3oTfXiOKpB0hTJIe6frqQDo= github.com/uptrace/bun/dialect/pgdialect v0.4.3/go.mod h1:BaNvWejl32oKUhwpFkw/eNcWldzIlVY4nfw/sNul0s8= github.com/uptrace/bun/dialect/sqlitedialect v0.4.3 h1:h+vqLGCeY22PFrbCOpQqK5+/p1qWCXYIhIUm/D5Vw08= diff --git a/vendor/github.com/uptrace/bun/CHANGELOG.md b/vendor/github.com/uptrace/bun/CHANGELOG.md @@ -1,5 +1,30 @@ # Changelog +## v1.0.4 - Sep 06 2021 + +- Added support for MariaDB. +- Restored default `SET` for `ON CONFLICT DO UPDATE` queries. + +## v1.0.3 - Sep 06 2021 + +- Fixed bulk soft deletes. +- pgdialect: fixed scanning into an array pointer. + +## v1.0.2 - Sep 04 2021 + +- Changed to completely ignore fields marked with `bun:"-"`. If you want to be able to scan into + such columns, use `bun:",scanonly"`. +- pgdriver: fixed SASL authentication handling. + +## v1.0.1 - Sep 02 2021 + +- pgdriver: added erroneous zero writes retry. +- Improved column handling in Relation callback. + +## v1.0.0 - Sep 01 2021 + +- First stable release. + ## v0.4.1 - Aug 18 2021 - Fixed migrate package to properly rollback migrations. diff --git a/vendor/github.com/uptrace/bun/CONTRIBUTING.md b/vendor/github.com/uptrace/bun/CONTRIBUTING.md @@ -0,0 +1,30 @@ +## Running tests + +To run tests, you need Docker which starts PostgreSQL and MySQL servers: + +```shell +cd internal/dbtest +./test.sh +``` + +## Releasing + +1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: + +```shell +./scripts/release.sh -t v1.0.0 +``` + +2. Open a pull request and wait for the build to finish. + +3. Merge the pull request and run `tag.sh` to create tags for packages: + +```shell +./scripts/tag.sh -t v1.0.0 +``` + +4. Push the tags: + +```shell +git push origin --tags +``` diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md @@ -4,17 +4,21 @@ </a> </p> -# Simple and performant SQL database client +# Simple and performant client for PostgreSQL, MySQL, and SQLite [![build workflow](https://github.com/uptrace/bun/actions/workflows/build.yml/badge.svg)](https://github.com/uptrace/bun/actions) [![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun)](https://pkg.go.dev/github.com/uptrace/bun) [![Documentation](https://img.shields.io/badge/bun-documentation-informational)](https://bun.uptrace.dev/) [![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj) +**Status**: API freeze (stable release). Note that all sub-packages (mainly extra/\* packages) are +not part of the API freeze and are developed independently. You can think of them as 3-rd party +packages that share one repo with the core. + Main features are: - Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql), - [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql), + [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql) (including MariaDB), [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite). - [Selecting](/example/basic/) into a map, struct, slice of maps/structs/vars. - [Bulk inserts](https://bun.uptrace.dev/guide/queries.html#insert). @@ -96,7 +100,7 @@ You also need to install a database/sql driver and the corresponding Bun ## Quickstart First you need to create a `sql.DB`. Here we are using the -[sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) driver which choses +[sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) driver which chooses between [modernc.org/sqlite](https://modernc.org/sqlite/) and [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) depending on your platform. @@ -109,7 +113,8 @@ if err != nil { } ``` -And then create a `bun.DB` on top of it using the corresponding SQLite dialect: +And then create a `bun.DB` on top of it using the corresponding SQLite +[dialect](https://bun.uptrace.dev/guide/drivers.html) that comes with Bun: ```go import ( diff --git a/vendor/github.com/uptrace/bun/RELEASING.md b/vendor/github.com/uptrace/bun/RELEASING.md @@ -1,21 +0,0 @@ -# Releasing - -1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: - -```shell -./scripts/release.sh -t v1.0.0 -``` - -2. Open a pull request and wait for the build to finish. - -3. Merge the pull request and run `tag.sh` to create tags for packages: - -```shell -./scripts/tag.sh -t v1.0.0 -``` - -4. Push the tags: - -```shell -git push origin --tags -``` diff --git a/vendor/github.com/uptrace/bun/bun.go b/vendor/github.com/uptrace/bun/bun.go @@ -5,19 +5,17 @@ import ( "fmt" "reflect" + "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" ) type ( Safe = schema.Safe Ident = schema.Ident -) - -type NullTime = schema.NullTime -type BaseModel = schema.BaseModel + NullTime = schema.NullTime + BaseModel = schema.BaseModel -type ( BeforeScanHook = schema.BeforeScanHook AfterScanHook = schema.AfterScanHook ) @@ -70,6 +68,11 @@ type AfterDropTableHook interface { AfterDropTable(ctx context.Context, query *DropTableQuery) error } +// SetLogger overwriters default Bun logger. +func SetLogger(logger internal.Logging) { + internal.Logger = logger +} + //------------------------------------------------------------------------------ type InValues struct { diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go @@ -3,7 +3,6 @@ package bun import ( "context" "database/sql" - "errors" "fmt" "reflect" "strings" @@ -473,30 +472,9 @@ func (tx Tx) NewDropColumn() *DropColumnQuery { return NewDropColumnQuery(tx.db).Conn(tx) } -//------------------------------------------------------------------------------0 +//------------------------------------------------------------------------------ func (db *DB) makeQueryBytes() []byte { // TODO: make this configurable? return make([]byte, 0, 4096) } - -//------------------------------------------------------------------------------ - -type result struct { - r sql.Result - n int -} - -func (r result) RowsAffected() (int64, error) { - if r.r != nil { - return r.r.RowsAffected() - } - return int64(r.n), nil -} - -func (r result) LastInsertId() (int64, error) { - if r.r != nil { - return r.r.LastInsertId() - } - return 0, errors.New("LastInsertId is not available") -} diff --git a/vendor/github.com/uptrace/bun/dialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/dialect.go @@ -8,10 +8,8 @@ func (n Name) String() string { return "pg" case SQLite: return "sqlite" - case MySQL5: - return "mysql5" - case MySQL8: - return "mysql8" + case MySQL: + return "mysql" default: return "invalid" } @@ -21,6 +19,5 @@ const ( Invalid Name = iota PG SQLite - MySQL5 - MySQL8 + MySQL ) diff --git a/vendor/github.com/uptrace/bun/dialect/feature/feature.go b/vendor/github.com/uptrace/bun/dialect/feature/feature.go @@ -4,10 +4,9 @@ import "github.com/uptrace/bun/internal" type Feature = internal.Flag -const DefaultFeatures = Returning | TableCascade - const ( - Returning Feature = 1 << iota + CTE Feature = 1 << iota + Returning DefaultPlaceholder DoubleColonCast ValuesRow diff --git a/vendor/github.com/uptrace/bun/go.sum b/vendor/github.com/uptrace/bun/go.sum @@ -20,4 +20,4 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +\ No newline at end of file diff --git a/vendor/github.com/uptrace/bun/join.go b/vendor/github.com/uptrace/bun/join.go @@ -8,18 +8,18 @@ import ( "github.com/uptrace/bun/schema" ) -type join struct { - Parent *join +type relationJoin struct { + Parent *relationJoin BaseModel tableModel JoinModel tableModel Relation *schema.Relation - ApplyQueryFunc func(*SelectQuery) *SelectQuery - columns []schema.QueryWithArgs + apply func(*SelectQuery) *SelectQuery + columns []schema.QueryWithArgs } -func (j *join) applyQuery(q *SelectQuery) { - if j.ApplyQueryFunc == nil { +func (j *relationJoin) applyTo(q *SelectQuery) { + if j.apply == nil { return } @@ -30,24 +30,20 @@ func (j *join) applyQuery(q *SelectQuery) { table, q.table = q.table, j.JoinModel.Table() columns, q.columns = q.columns, nil - q = j.ApplyQueryFunc(q) + q = j.apply(q) // Restore state. q.table = table j.columns, q.columns = q.columns, columns } -func (j *join) Select(ctx context.Context, q *SelectQuery) error { +func (j *relationJoin) Select(ctx context.Context, q *SelectQuery) error { switch j.Relation.Type { - case schema.HasManyRelation: - return j.selectMany(ctx, q) - case schema.ManyToManyRelation: - return j.selectM2M(ctx, q) } panic("not reached") } -func (j *join) selectMany(ctx context.Context, q *SelectQuery) error { +func (j *relationJoin) selectMany(ctx context.Context, q *SelectQuery) error { q = j.manyQuery(q) if q == nil { return nil @@ -55,7 +51,7 @@ func (j *join) selectMany(ctx context.Context, q *SelectQuery) error { return q.Scan(ctx) } -func (j *join) manyQuery(q *SelectQuery) *SelectQuery { +func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { hasManyModel := newHasManyModel(j) if hasManyModel == nil { return nil @@ -86,13 +82,13 @@ func (j *join) manyQuery(q *SelectQuery) *SelectQuery { q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) } - j.applyQuery(q) + j.applyTo(q) q = q.Apply(j.hasManyColumns) return q } -func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery { +func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery { if j.Relation.M2MTable != nil { q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*") } @@ -122,7 +118,7 @@ func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery { return q } -func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error { +func (j *relationJoin) selectM2M(ctx context.Context, q *SelectQuery) error { q = j.m2mQuery(q) if q == nil { return nil @@ -130,7 +126,7 @@ func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error { return q.Scan(ctx) } -func (j *join) m2mQuery(q *SelectQuery) *SelectQuery { +func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { fmter := q.db.fmter m2mModel := newM2MModel(j) @@ -170,13 +166,13 @@ func (j *join) m2mQuery(q *SelectQuery) *SelectQuery { j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) } - j.applyQuery(q) + j.applyTo(q) q = q.Apply(j.hasManyColumns) return q } -func (j *join) hasParent() bool { +func (j *relationJoin) hasParent() bool { if j.Parent != nil { switch j.Parent.Relation.Type { case schema.HasOneRelation, schema.BelongsToRelation: @@ -186,7 +182,7 @@ func (j *join) hasParent() bool { return false } -func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte { +func (j *relationJoin) appendAlias(fmter schema.Formatter, b []byte) []byte { quote := fmter.IdentQuote() b = append(b, quote) @@ -195,7 +191,7 @@ func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte { return b } -func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { +func (j *relationJoin) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { quote := fmter.IdentQuote() b = append(b, quote) @@ -206,7 +202,7 @@ func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string return b } -func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { +func (j *relationJoin) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { quote := fmter.IdentQuote() if j.hasParent() { @@ -218,7 +214,7 @@ func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { return append(b, j.BaseModel.Table().SQLAlias...) } -func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte { +func (j *relationJoin) appendSoftDelete(b []byte, flags internal.Flag) []byte { b = append(b, '.') b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...) if flags.Has(deletedFlag) { @@ -229,7 +225,7 @@ func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte { return b } -func appendAlias(b []byte, j *join) []byte { +func appendAlias(b []byte, j *relationJoin) []byte { if j.hasParent() { b = appendAlias(b, j.Parent) b = append(b, "__"...) @@ -238,7 +234,7 @@ func appendAlias(b []byte, j *join) []byte { return b } -func (j *join) appendHasOneJoin( +func (j *relationJoin) appendHasOneJoin( fmter schema.Formatter, b []byte, q *SelectQuery, ) (_ []byte, err error) { isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) diff --git a/vendor/github.com/uptrace/bun/migrate/migrator.go b/vendor/github.com/uptrace/bun/migrate/migrator.go @@ -115,6 +115,7 @@ func (m *Migrator) Reset(ctx context.Context) error { return m.Init(ctx) } +// Migrate runs unapplied migrations. If a migration fails, migrate immediately exits. func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { cfg := newMigrationConfig(opts) @@ -146,7 +147,7 @@ func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*Migra if !cfg.nop && migration.Up != nil { if err := migration.Up(ctx, m.db); err != nil { - return nil, err + return group, err } } diff --git a/vendor/github.com/uptrace/bun/model.go b/vendor/github.com/uptrace/bun/model.go @@ -38,16 +38,16 @@ type tableModel interface { Table() *schema.Table Relation() *schema.Relation - Join(string, func(*SelectQuery) *SelectQuery) *join - GetJoin(string) *join - GetJoins() []join - AddJoin(join) *join + Join(string) *relationJoin + GetJoin(string) *relationJoin + GetJoins() []relationJoin + AddJoin(relationJoin) *relationJoin Root() reflect.Value ParentIndex() []int Mount(reflect.Value) - updateSoftDeleteField() error + updateSoftDeleteField(time.Time) error } func newModel(db *DB, dest []interface{}) (model, error) { diff --git a/vendor/github.com/uptrace/bun/model_table_has_many.go b/vendor/github.com/uptrace/bun/model_table_has_many.go @@ -21,7 +21,7 @@ type hasManyModel struct { var _ tableModel = (*hasManyModel)(nil) -func newHasManyModel(j *join) *hasManyModel { +func newHasManyModel(j *relationJoin) *hasManyModel { baseTable := j.BaseModel.Table() joinModel := j.JoinModel.(*sliceTableModel) baseValues := baseValues(joinModel, j.Relation.BaseFields) diff --git a/vendor/github.com/uptrace/bun/model_table_m2m.go b/vendor/github.com/uptrace/bun/model_table_m2m.go @@ -21,7 +21,7 @@ type m2mModel struct { var _ tableModel = (*m2mModel)(nil) -func newM2MModel(j *join) *m2mModel { +func newM2MModel(j *relationJoin) *m2mModel { baseTable := j.BaseModel.Table() joinModel := j.JoinModel.(*sliceTableModel) baseValues := baseValues(joinModel, baseTable.PKs) diff --git a/vendor/github.com/uptrace/bun/model_table_slice.go b/vendor/github.com/uptrace/bun/model_table_slice.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "reflect" + "time" "github.com/uptrace/bun/schema" ) @@ -45,8 +46,8 @@ func (m *sliceTableModel) init(sliceType reflect.Type) { } } -func (m *sliceTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { - return m.join(m.slice, name, apply) +func (m *sliceTableModel) Join(name string) *relationJoin { + return m.join(m.slice, name) } func (m *sliceTableModel) Bind(bind reflect.Value) { @@ -100,12 +101,12 @@ var ( _ schema.AfterScanHook = (*sliceTableModel)(nil) ) -func (m *sliceTableModel) updateSoftDeleteField() error { +func (m *sliceTableModel) updateSoftDeleteField(tm time.Time) error { sliceLen := m.slice.Len() for i := 0; i < sliceLen; i++ { strct := indirect(m.slice.Index(i)) fv := m.table.SoftDeleteField.Value(strct) - if err := m.table.UpdateSoftDeleteField(fv); err != nil { + if err := m.table.UpdateSoftDeleteField(fv, tm); err != nil { return err } } diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go @@ -6,8 +6,8 @@ import ( "fmt" "reflect" "strings" + "time" - "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/schema" ) @@ -16,7 +16,7 @@ type structTableModel struct { table *schema.Table rel *schema.Relation - joins []join + joins []relationJoin dest interface{} root reflect.Value @@ -151,7 +151,7 @@ func (m *structTableModel) AfterScan(ctx context.Context) error { return firstErr } -func (m *structTableModel) GetJoin(name string) *join { +func (m *structTableModel) GetJoin(name string) *relationJoin { for i := range m.joins { j := &m.joins[i] if j.Relation.Field.Name == name || j.Relation.Field.GoName == name { @@ -161,30 +161,28 @@ func (m *structTableModel) GetJoin(name string) *join { return nil } -func (m *structTableModel) GetJoins() []join { +func (m *structTableModel) GetJoins() []relationJoin { return m.joins } -func (m *structTableModel) AddJoin(j join) *join { +func (m *structTableModel) AddJoin(j relationJoin) *relationJoin { m.joins = append(m.joins, j) return &m.joins[len(m.joins)-1] } -func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { - return m.join(m.strct, name, apply) +func (m *structTableModel) Join(name string) *relationJoin { + return m.join(m.strct, name) } -func (m *structTableModel) join( - bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery, -) *join { +func (m *structTableModel) join(bind reflect.Value, name string) *relationJoin { path := strings.Split(name, ".") index := make([]int, 0, len(path)) - currJoin := join{ + currJoin := relationJoin{ BaseModel: m, JoinModel: m, } - var lastJoin *join + var lastJoin *relationJoin for _, name := range path { relation, ok := currJoin.JoinModel.Table().Relations[name] @@ -214,20 +212,12 @@ func (m *structTableModel) join( } } - // No joins with such name. - if lastJoin == nil { - return nil - } - if apply != nil { - lastJoin.ApplyQueryFunc = apply - } - return lastJoin } -func (m *structTableModel) updateSoftDeleteField() error { +func (m *structTableModel) updateSoftDeleteField(tm time.Time) error { fv := m.table.SoftDeleteField.Value(m.strct) - return m.table.UpdateSoftDeleteField(fv) + return m.table.UpdateSoftDeleteField(fv, tm) } func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { @@ -235,20 +225,24 @@ func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, e return 0, rows.Err() } + var n int + if err := m.ScanRow(ctx, rows); err != nil { return 0, err } + n++ - // For inserts, SQLite3 can return a row like it was inserted sucessfully and then - // an actual error for the next row. See issues/100. - if m.db.dialect.Name() == dialect.SQLite { - _ = rows.Next() - if err := rows.Err(); err != nil { - return 0, err - } + // And discard the rest. This is especially important for SQLite3, which can return + // a row like it was inserted sucessfully and then return an actual error for the next row. + // See issues/100. + for rows.Next() { + n++ + } + if err := rows.Err(); err != nil { + return 0, err } - return 1, nil + return n, nil } func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { @@ -305,6 +299,9 @@ func (m *structTableModel) scanColumn(column string, src interface{}) (bool, err } if field, ok := m.table.FieldMap[column]; ok { + if src == nil && m.isNil() { + return true, nil + } return true, field.ScanValue(m.strct, src) } @@ -312,6 +309,7 @@ func (m *structTableModel) scanColumn(column string, src interface{}) (bool, err if join := m.GetJoin(joinName); join != nil { return true, join.JoinModel.ScanColumn(column, src) } + if m.table.ModelName == joinName { return true, m.ScanColumn(column, src) } @@ -320,6 +318,10 @@ func (m *structTableModel) scanColumn(column string, src interface{}) (bool, err return false, nil } +func (m *structTableModel) isNil() bool { + return m.strct.Kind() == reflect.Ptr && m.strct.IsNil() +} + func (m *structTableModel) AppendNamedArg( fmter schema.Formatter, b []byte, name string, ) ([]byte, bool) { diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go @@ -3,6 +3,7 @@ package bun import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" @@ -262,7 +263,10 @@ func (q *baseQuery) _excludeColumn(column string) bool { //------------------------------------------------------------------------------ func (q *baseQuery) modelHasTableName() bool { - return !q.modelTable.IsZero() || q.table != nil + if !q.modelTable.IsZero() { + return q.modelTable.Query != "" + } + return q.table != nil } func (q *baseQuery) hasTables() bool { @@ -387,18 +391,10 @@ func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, e } func (q *baseQuery) getFields() ([]*schema.Field, error) { - table := q.tableModel.Table() - if len(q.columns) == 0 { - return table.Fields, nil + return q.table.Fields, nil } - - fields, err := q._getFields(false) - if err != nil { - return nil, err - } - - return fields, nil + return q._getFields(false) } func (q *baseQuery) getDataFields() ([]*schema.Field, error) { @@ -435,28 +431,28 @@ func (q *baseQuery) scan( query string, model model, hasDest bool, -) (res result, _ error) { +) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) rows, err := q.conn.QueryContext(ctx, query) if err != nil { q.db.afterQuery(ctx, event, nil, err) - return res, err + return nil, err } defer rows.Close() - n, err := model.ScanRows(ctx, rows) + numRow, err := model.ScanRows(ctx, rows) if err != nil { q.db.afterQuery(ctx, event, nil, err) - return res, err + return nil, err } - res.n = n - if n == 0 && hasDest && isSingleRowModel(model) { + if numRow == 0 && hasDest && isSingleRowModel(model) { err = sql.ErrNoRows } - q.db.afterQuery(ctx, event, nil, err) + res := driver.RowsAffected(numRow) + q.db.afterQuery(ctx, event, res, err) return res, err } @@ -465,18 +461,16 @@ func (q *baseQuery) exec( ctx context.Context, queryApp schema.QueryAppender, query string, -) (res result, _ error) { +) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) - r, err := q.conn.ExecContext(ctx, query) + res, err := q.conn.ExecContext(ctx, query) if err != nil { q.db.afterQuery(ctx, event, nil, err) return res, err } - res.r = r - - q.db.afterQuery(ctx, event, nil, err) + q.db.afterQuery(ctx, event, res, err) return res, nil } @@ -556,10 +550,12 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) return } - where[0].Sep = "" + q.addWhere(schema.SafeQueryWithSep("", nil, sep)) + q.addWhere(schema.SafeQueryWithSep("", nil, "(")) - q.addWhere(schema.SafeQueryWithSep("", nil, sep+"(")) + where[0].Sep = "" q.where = append(q.where, where...) + q.addWhere(schema.SafeQueryWithSep("", nil, ")")) } @@ -623,11 +619,11 @@ func appendWhere( fmter schema.Formatter, b []byte, where []schema.QueryWithSep, ) (_ []byte, err error) { for i, where := range where { - if i > 0 || where.Sep == "(" { + if i > 0 { b = append(b, where.Sep...) } - if where.Query == "" && where.Args == nil { + if where.Query == "" { continue } diff --git a/vendor/github.com/uptrace/bun/query_delete.go b/vendor/github.com/uptrace/bun/query_delete.go @@ -3,6 +3,7 @@ package bun import ( "context" "database/sql" + "time" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -135,15 +136,18 @@ func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e fmter = formatterWithModel(fmter, q) if q.isSoftDelete() { - if err := q.tableModel.updateSoftDeleteField(); err != nil { + now := time.Now() + + if err := q.tableModel.updateSoftDeleteField(now); err != nil { return nil, err } - upd := UpdateQuery{ + upd := &UpdateQuery{ whereBaseQuery: q.whereBaseQuery, returningQuery: q.returningQuery, } - upd.Column(q.table.SoftDeleteField.Name) + upd.Set(q.softDeleteSet(fmter, now)) + return upd.AppendQuery(fmter, b) } @@ -193,6 +197,18 @@ func (q *DeleteQuery) isSoftDelete() bool { return q.tableModel != nil && q.table.SoftDeleteField != nil && !q.flags.Has(forceDeleteFlag) } +func (q *DeleteQuery) softDeleteSet(fmter schema.Formatter, tm time.Time) string { + b := make([]byte, 0, 32) + if fmter.HasFeature(feature.UpdateMultiTable) { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, q.table.SoftDeleteField.SQLName...) + b = append(b, " = "...) + b = q.db.Dialect().Append(fmter, b, tm) + return internal.String(b) +} + //------------------------------------------------------------------------------ func (q *DeleteQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "reflect" + "strings" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -16,7 +17,7 @@ type InsertQuery struct { returningQuery customValueQuery - onConflict schema.QueryWithArgs + on schema.QueryWithArgs setQuery ignore bool @@ -88,13 +89,13 @@ func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { return q } -// Value overwrites model value for the column in INSERT and UPDATE queries. -func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery { +// Value overwrites model value for the column. +func (q *InsertQuery) Value(column string, expr string, args ...interface{}) *InsertQuery { if q.table == nil { q.err = errNilModel return q } - q.addValue(q.table, column, value, args) + q.addValue(q.table, column, expr, args) return q } @@ -162,7 +163,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e } b = append(b, "INTO "...) - if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() { + if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() { b, err = q.appendFirstTableWithAlias(fmter, b) } else { b, err = q.appendFirstTable(fmter, b) @@ -382,7 +383,7 @@ func (q *InsertQuery) appendFields( //------------------------------------------------------------------------------ func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery { - q.onConflict = schema.SafeQuery(s, args) + q.on = schema.SafeQuery(s, args) return q } @@ -392,12 +393,12 @@ func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery { } func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) { - if q.onConflict.IsZero() { + if q.on.IsZero() { return b, nil } b = append(b, " ON "...) - b, err = q.onConflict.AppendQuery(fmter, b) + b, err = q.on.AppendQuery(fmter, b) if err != nil { return nil, err } @@ -413,7 +414,7 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err if err != nil { return nil, err } - } else if len(q.columns) > 0 { + } else if q.onConflictDoUpdate() { fields, err := q.getDataFields() if err != nil { return nil, err @@ -434,6 +435,10 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err return b, nil } +func (q *InsertQuery) onConflictDoUpdate() bool { + return strings.HasSuffix(strings.ToUpper(q.on.Query), " DO UPDATE") +} + func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte { b = append(b, " SET "...) for i, f := range fields { diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go @@ -286,41 +286,38 @@ func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *Selec //------------------------------------------------------------------------------ -// Relation adds a relation to the query. Relation name can be: -// - RelationName to select all columns, -// - RelationName.column_name, -// - RelationName._ to join relation without selecting relation columns. +// Relation adds a relation to the query. func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery { + if len(apply) > 1 { + panic("only one apply function is supported") + } + if q.tableModel == nil { q.setErr(errNilModel) return q } - var fn func(*SelectQuery) *SelectQuery - - if len(apply) == 1 { - fn = apply[0] - } else if len(apply) > 1 { - panic("only one apply function is supported") - } - - join := q.tableModel.Join(name, fn) + join := q.tableModel.Join(name) if join == nil { q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) return q } + if len(apply) == 1 { + join.apply = apply[0] + } + return q } -func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error { +func (q *SelectQuery) forEachHasOneJoin(fn func(*relationJoin) error) error { if q.tableModel == nil { return nil } return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) } -func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error { +func (q *SelectQuery) _forEachHasOneJoin(fn func(*relationJoin) error, joins []relationJoin) error { for i := range joins { j := &joins[i] switch j.Relation.Type { @@ -336,16 +333,23 @@ func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) err return nil } -func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error { - var err error +func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) error { for i := range joins { j := &joins[i] + + var err error + switch j.Relation.Type { case schema.HasOneRelation, schema.BelongsToRelation: err = q.selectJoins(ctx, j.JoinModel.GetJoins()) + case schema.HasManyRelation: + err = j.selectMany(ctx, q.db.NewSelect()) + case schema.ManyToManyRelation: + err = j.selectM2M(ctx, q.db.NewSelect()) default: - err = j.Select(ctx, q.db.NewSelect()) + panic("not reached") } + if err != nil { return err } @@ -415,7 +419,7 @@ func (q *SelectQuery) appendQuery( } } - if err := q.forEachHasOneJoin(func(j *join) error { + if err := q.forEachHasOneJoin(func(j *relationJoin) error { b = append(b, ' ') b, err = j.appendHasOneJoin(fmter, b, q) return err @@ -545,13 +549,13 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, b = append(b, '*') } - if err := q.forEachHasOneJoin(func(j *join) error { + if err := q.forEachHasOneJoin(func(join *relationJoin) error { if len(b) != start { b = append(b, ", "...) start = len(b) } - b, err = q.appendHasOneColumns(fmter, b, j) + b, err = q.appendHasOneColumns(fmter, b, join) if err != nil { return err } @@ -567,18 +571,19 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, } func (q *SelectQuery) appendHasOneColumns( - fmter schema.Formatter, b []byte, join *join, + fmter schema.Formatter, b []byte, join *relationJoin, ) (_ []byte, err error) { - join.applyQuery(q) + join.applyTo(q) if join.columns != nil { + table := join.JoinModel.Table() for i, col := range join.columns { if i > 0 { b = append(b, ", "...) } if col.Args == nil { - if field, ok := q.table.FieldMap[col.Query]; ok { + if field, ok := table.FieldMap[col.Query]; ok { b = join.appendAlias(fmter, b) b = append(b, '.') b = append(b, field.SQLName...) @@ -691,7 +696,7 @@ func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { return err } - if res.n > 0 { + if n, _ := res.RowsAffected(); n > 0 { if tableModel, ok := model.(tableModel); ok { if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil { return err diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go @@ -90,13 +90,13 @@ func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery { return q } -// Value overwrites model value for the column in INSERT and UPDATE queries. -func (q *UpdateQuery) Value(column string, value string, args ...interface{}) *UpdateQuery { +// Value overwrites model value for the column. +func (q *UpdateQuery) Value(column string, expr string, args ...interface{}) *UpdateQuery { if q.table == nil { q.err = errNilModel return q } - q.addValue(q.table, column, value, args) + q.addValue(q.table, column, expr, args) return q } @@ -321,20 +321,36 @@ func (q *UpdateQuery) Bulk() *UpdateQuery { return q } - return q.With("_data", q.db.NewValues(model)). + set, err := q.updateSliceSet(q.db.fmter, model) + if err != nil { + q.setErr(err) + return q + } + + values := q.db.NewValues(model) + values.customValueQuery = q.customValueQuery + + return q.With("_data", values). Model(model). TableExpr("_data"). - Set(q.updateSliceSet(model)). + Set(set). Where(q.updateSliceWhere(model)) } -func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string { +func (q *UpdateQuery) updateSliceSet( + fmter schema.Formatter, model *sliceTableModel, +) (string, error) { + fields, err := q.getDataFields() + if err != nil { + return "", err + } + var b []byte - for i, field := range model.table.DataFields { + for i, field := range fields { if i > 0 { b = append(b, ", "...) } - if q.db.fmter.HasFeature(feature.UpdateMultiTable) { + if fmter.HasFeature(feature.UpdateMultiTable) { b = append(b, model.table.SQLAlias...) b = append(b, '.') } @@ -342,7 +358,7 @@ func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string { b = append(b, " = _data."...) b = append(b, field.SQLName...) } - return internal.String(b) + return internal.String(b), nil } func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string { diff --git a/vendor/github.com/uptrace/bun/query_values.go b/vendor/github.com/uptrace/bun/query_values.go @@ -34,6 +34,16 @@ func (q *ValuesQuery) Conn(db IConn) *ValuesQuery { return q } +// Value overwrites model value for the column. +func (q *ValuesQuery) Value(column string, expr string, args ...interface{}) *ValuesQuery { + if q.table == nil { + q.err = errNilModel + return q + } + q.addValue(q.table, column, expr, args) + return q +} + func (q *ValuesQuery) WithOrder() *ValuesQuery { q.withOrder = true return q diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -2,7 +2,6 @@ package schema import ( "database/sql/driver" - "encoding/json" "fmt" "net" "reflect" @@ -14,16 +13,6 @@ import ( "github.com/uptrace/bun/internal" ) -var ( - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() - jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() - - driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() -) - type ( AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte CustomAppender func(typ reflect.Type) AppenderFunc @@ -60,6 +49,8 @@ var appenders = []AppenderFunc{ func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { switch typ { + case bytesType: + return appendBytesValue case timeType: return appendTimeValue case ipType: @@ -93,7 +84,9 @@ func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { case reflect.Interface: return ifaceAppenderFunc(typ, custom) case reflect.Ptr: - return ptrAppenderFunc(typ, custom) + if fn := Appender(typ.Elem(), custom); fn != nil { + return PtrAppender(fn) + } case reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { return appendBytesValue @@ -123,13 +116,12 @@ func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) } } -func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { - appender := Appender(typ.Elem(), custom) +func PtrAppender(fn AppenderFunc) AppenderFunc { return func(fmter Formatter, b []byte, v reflect.Value) []byte { if v.IsNil() { return dialect.AppendNull(b) } - return appender(fmter, b, v.Elem()) + return fn(fmter, b, v.Elem()) } } diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go @@ -89,10 +89,10 @@ func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) [] func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte { var namedArgs NamedArgAppender if len(args) == 1 { - var ok bool - namedArgs, ok = args[0].(NamedArgAppender) - if !ok { - namedArgs, _ = newStructArgs(f, args[0]) + if v, ok := args[0].(NamedArgAppender); ok { + namedArgs = v + } else if v, ok := newStructArgs(f, args[0]); ok { + namedArgs = v } } diff --git a/vendor/github.com/uptrace/bun/schema/reflect.go b/vendor/github.com/uptrace/bun/schema/reflect.go @@ -0,0 +1,70 @@ +package schema + +import ( + "database/sql/driver" + "encoding/json" + "net" + "reflect" + "time" +) + +var ( + bytesType = reflect.TypeOf((*[]byte)(nil)).Elem() + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() +) + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + v = indirectNil(v) + } + v = v.Field(idx) + } + return v +} + +func indirectNil(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go @@ -7,10 +7,12 @@ import ( "net" "reflect" "strconv" + "strings" "time" "github.com/vmihailenco/msgpack/v5" + "github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/extra/bunjson" "github.com/uptrace/bun/internal" ) @@ -19,32 +21,35 @@ var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() type ScannerFunc func(dest reflect.Value, src interface{}) error -var scanners = []ScannerFunc{ - reflect.Bool: scanBool, - reflect.Int: scanInt64, - reflect.Int8: scanInt64, - reflect.Int16: scanInt64, - reflect.Int32: scanInt64, - reflect.Int64: scanInt64, - reflect.Uint: scanUint64, - reflect.Uint8: scanUint64, - reflect.Uint16: scanUint64, - reflect.Uint32: scanUint64, - reflect.Uint64: scanUint64, - reflect.Uintptr: scanUint64, - reflect.Float32: scanFloat64, - reflect.Float64: scanFloat64, - reflect.Complex64: nil, - reflect.Complex128: nil, - reflect.Array: nil, - reflect.Chan: nil, - reflect.Func: nil, - reflect.Map: scanJSON, - reflect.Ptr: nil, - reflect.Slice: scanJSON, - reflect.String: scanString, - reflect.Struct: scanJSON, - reflect.UnsafePointer: nil, +var scanners []ScannerFunc + +func init() { + scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Interface: scanInterface, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, + } } func FieldScanner(dialect Dialect, field *Field) ScannerFunc { @@ -54,6 +59,12 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc { if field.Tag.HasOption("json_use_number") { return scanJSONUseNumber } + if field.StructField.Type.Kind() == reflect.Interface { + switch strings.ToUpper(field.UserSQLType) { + case sqltype.JSON, sqltype.JSONB: + return scanJSONIntoInterface + } + } return dialect.Scanner(field.StructField.Type) } @@ -62,7 +73,7 @@ func Scanner(typ reflect.Type) ScannerFunc { if kind == reflect.Ptr { if fn := Scanner(typ.Elem()); fn != nil { - return ptrScanner(fn) + return PtrScanner(fn) } } @@ -84,6 +95,8 @@ func Scanner(typ reflect.Type) ScannerFunc { return scanIP case ipNetType: return scanIPNet + case bytesType: + return scanBytes case jsonRawMessageType: return scanJSONRawMessage } @@ -196,6 +209,21 @@ func scanString(dest reflect.Value, src interface{}) error { return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } +func scanBytes(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBytes(nil) + return nil + case string: + dest.SetBytes([]byte(src)) + return nil + case []byte: + dest.SetBytes(src) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + func scanTime(dest reflect.Value, src interface{}) error { switch src := src.(type) { case nil: @@ -352,7 +380,7 @@ func toBytes(src interface{}) ([]byte, error) { } } -func ptrScanner(fn ScannerFunc) ScannerFunc { +func PtrScanner(fn ScannerFunc) ScannerFunc { return func(dest reflect.Value, src interface{}) error { if src == nil { if !dest.CanAddr() { @@ -383,6 +411,43 @@ func scanNull(dest reflect.Value) error { return nil } +func scanJSONIntoInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + if src == nil { + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) + } + + dest = dest.Elem() + if fn := Scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + if src == nil { + return nil + } + dest.Set(reflect.ValueOf(src)) + return nil + } + + dest = dest.Elem() + if fn := Scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + func nilable(kind reflect.Kind) bool { switch kind { case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go @@ -40,7 +40,7 @@ type QueryWithArgs struct { var _ QueryAppender = QueryWithArgs{} func SafeQuery(query string, args []interface{}) QueryWithArgs { - if query != "" && args == nil { + if args == nil { args = make([]interface{}, 0) } return QueryWithArgs{Query: query, Args: args} diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go @@ -23,32 +23,29 @@ var ( ) var sqlTypes = []string{ - reflect.Bool: sqltype.Boolean, - reflect.Int: sqltype.BigInt, - reflect.Int8: sqltype.SmallInt, - reflect.Int16: sqltype.SmallInt, - reflect.Int32: sqltype.Integer, - reflect.Int64: sqltype.BigInt, - reflect.Uint: sqltype.BigInt, - reflect.Uint8: sqltype.SmallInt, - reflect.Uint16: sqltype.SmallInt, - reflect.Uint32: sqltype.Integer, - reflect.Uint64: sqltype.BigInt, - reflect.Uintptr: sqltype.BigInt, - reflect.Float32: sqltype.Real, - reflect.Float64: sqltype.DoublePrecision, - reflect.Complex64: "", - reflect.Complex128: "", - reflect.Array: "", - reflect.Chan: "", - reflect.Func: "", - reflect.Interface: "", - reflect.Map: sqltype.VarChar, - reflect.Ptr: "", - reflect.Slice: sqltype.VarChar, - reflect.String: sqltype.VarChar, - reflect.Struct: sqltype.VarChar, - reflect.UnsafePointer: "", + reflect.Bool: sqltype.Boolean, + reflect.Int: sqltype.BigInt, + reflect.Int8: sqltype.SmallInt, + reflect.Int16: sqltype.SmallInt, + reflect.Int32: sqltype.Integer, + reflect.Int64: sqltype.BigInt, + reflect.Uint: sqltype.BigInt, + reflect.Uint8: sqltype.SmallInt, + reflect.Uint16: sqltype.SmallInt, + reflect.Uint32: sqltype.Integer, + reflect.Uint64: sqltype.BigInt, + reflect.Uintptr: sqltype.BigInt, + reflect.Float32: sqltype.Real, + reflect.Float64: sqltype.DoublePrecision, + reflect.Complex64: "", + reflect.Complex128: "", + reflect.Array: "", + reflect.Interface: "", + reflect.Map: sqltype.VarChar, + reflect.Ptr: "", + reflect.Slice: sqltype.VarChar, + reflect.String: sqltype.VarChar, + reflect.Struct: sqltype.VarChar, } func DiscoverSQLType(typ reflect.Type) string { diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go @@ -60,10 +60,9 @@ type Table struct { Unique map[string][]*Field SoftDeleteField *Field - UpdateSoftDeleteField func(fv reflect.Value) error + UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error - allFields []*Field // read only - skippedFields []*Field + allFields []*Field // read only flags internal.Flag } @@ -104,9 +103,7 @@ func (t *Table) init1() { } func (t *Table) init2() { - t.initInlines() t.initRelations() - t.skippedFields = nil } func (t *Table) setName(name string) { @@ -207,15 +204,20 @@ func (t *Table) initFields() { func (t *Table) addFields(typ reflect.Type, baseIndex []int) { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) + unexported := f.PkgPath != "" - // Make a copy so slice is not shared between fields. + if unexported && !f.Anonymous { // unexported + continue + } + if f.Tag.Get("bun") == "-" { + continue + } + + // Make a copy so the slice is not shared between fields. index := make([]int, len(baseIndex)) copy(index, baseIndex) if f.Anonymous { - if f.Tag.Get("bun") == "-" { - continue - } if f.Name == "BaseModel" && f.Type == baseModelType { if len(index) == 0 { t.processBaseModelField(f) @@ -243,8 +245,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { continue } - field := t.newField(f, index) - if field != nil { + if field := t.newField(f, index); field != nil { t.addField(field) } } @@ -284,11 +285,10 @@ func (t *Table) processBaseModelField(f reflect.StructField) { func (t *Table) newField(f reflect.StructField, index []int) *Field { tag := tagparser.Parse(f.Tag.Get("bun")) - if f.PkgPath != "" { - return nil - } - sqlName := internal.Underscore(f.Name) + if tag.Name != "" { + sqlName = tag.Name + } if tag.Name != sqlName && isKnownFieldOption(tag.Name) { internal.Warn.Printf( @@ -303,11 +303,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } } - skip := tag.Name == "-" - if !skip && tag.Name != "" { - sqlName = tag.Name - } - index = append(index, f.Index...) if field := t.fieldWithLock(sqlName); field != nil { if indexEqual(field.Index, index) { @@ -371,9 +366,11 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } t.allFields = append(t.allFields, field) - if skip { - t.skippedFields = append(t.skippedFields, field) + if tag.HasOption("scanonly") { t.FieldMap[field.Name] = field + if field.IndirectType.Kind() == reflect.Struct { + t.inlineFields(field, nil) + } return nil } @@ -386,14 +383,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { return field } -func (t *Table) initInlines() { - for _, f := range t.skippedFields { - if f.IndirectType.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - //--------------------------------------------------------------------------------------- func (t *Table) initRelations() { @@ -745,17 +734,15 @@ func (t *Table) m2mRelation(field *Field) *Relation { return rel } -func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { - if path == nil { - path = map[reflect.Type]struct{}{ - t.Type: {}, - } +func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { + if seen == nil { + seen = map[reflect.Type]struct{}{t.Type: {}} } - if _, ok := path[field.IndirectType]; ok { + if _, ok := seen[field.IndirectType]; ok { return } - path[field.IndirectType] = struct{}{} + seen[field.IndirectType] = struct{}{} joinTable := t.dialect.Tables().Ref(field.IndirectType) for _, f := range joinTable.allFields { @@ -775,18 +762,15 @@ func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { continue } - if _, ok := path[f.IndirectType]; !ok { - t.inlineFields(f, path) + if _, ok := seen[f.IndirectType]; !ok { + t.inlineFields(f, seen) } } } //------------------------------------------------------------------------------ -func (t *Table) Dialect() Dialect { return t.dialect } - -//------------------------------------------------------------------------------ - +func (t *Table) Dialect() Dialect { return t.dialect } func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } @@ -845,6 +829,7 @@ func isKnownFieldOption(name string) bool { "default", "unique", "soft_delete", + "scanonly", "pk", "autoincrement", @@ -883,35 +868,35 @@ func parseRelationJoin(join string) ([]string, []string) { //------------------------------------------------------------------------------ -func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { +func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error { typ := field.StructField.Type switch typ { case timeType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*time.Time) - *ptr = time.Now() + *ptr = tm return nil } case nullTimeType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*sql.NullTime) - *ptr = sql.NullTime{Time: time.Now()} + *ptr = sql.NullTime{Time: tm} return nil } case nullIntType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*sql.NullInt64) - *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + *ptr = sql.NullInt64{Int64: tm.UnixNano()} return nil } } switch field.IndirectType.Kind() { case reflect.Int64: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*int64) - *ptr = time.Now().UnixNano() + *ptr = tm.UnixNano() return nil } case reflect.Ptr: @@ -922,17 +907,16 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { switch typ { //nolint:gocritic case timeType: - return func(fv reflect.Value) error { - now := time.Now() - fv.Set(reflect.ValueOf(&now)) + return func(fv reflect.Value, tm time.Time) error { + fv.Set(reflect.ValueOf(&tm)) return nil } } switch typ.Kind() { //nolint:gocritic case reflect.Int64: - return func(fv reflect.Value) error { - utime := time.Now().UnixNano() + return func(fv reflect.Value, tm time.Time) error { + utime := tm.UnixNano() fv.Set(reflect.ValueOf(&utime)) return nil } @@ -941,8 +925,8 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { return softDeleteFieldUpdaterFallback(field) } -func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error { - return func(fv reflect.Value) error { - return field.ScanWithCheck(fv, time.Now()) +func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error { + return func(fv reflect.Value, tm time.Time) error { + return field.ScanWithCheck(fv, tm) } } diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go @@ -67,6 +67,7 @@ func (t *Tables) Ref(typ reflect.Type) *Table { } func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { + typ = indirectType(typ) if typ.Kind() != reflect.Struct { panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) } diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/util.go @@ -1,53 +0,0 @@ -package schema - -import "reflect" - -func indirectType(t reflect.Type) reflect.Type { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t -} - -func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { - if len(index) == 1 { - return v.Field(index[0]), true - } - - for i, idx := range index { - if i > 0 { - if v.Kind() == reflect.Ptr { - if v.IsNil() { - return v, false - } - v = v.Elem() - } - } - v = v.Field(idx) - } - return v, true -} - -func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { - if len(index) == 1 { - return v.Field(index[0]) - } - - for i, idx := range index { - if i > 0 { - v = indirectNil(v) - } - v = v.Field(idx) - } - return v -} - -func indirectNil(v reflect.Value) reflect.Value { - if v.Kind() == reflect.Ptr { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - v = v.Elem() - } - return v -} diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go @@ -2,5 +2,5 @@ package bun // Version is the current release version. func Version() string { - return "0.4.3" + return "1.0.4" } diff --git a/vendor/modules.txt b/vendor/modules.txt @@ -394,7 +394,7 @@ github.com/tdewolff/parse/v2/strconv github.com/tmthrgd/go-hex # github.com/ugorji/go/codec v1.2.6 github.com/ugorji/go/codec -# github.com/uptrace/bun v0.4.3 +# github.com/uptrace/bun v1.0.4 ## explicit github.com/uptrace/bun github.com/uptrace/bun/dialect