gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

query_merge.go (7316B)


      1 package bun
      2 
      3 import (
      4 	"context"
      5 	"database/sql"
      6 	"errors"
      7 
      8 	"github.com/uptrace/bun/dialect"
      9 	"github.com/uptrace/bun/dialect/feature"
     10 	"github.com/uptrace/bun/internal"
     11 	"github.com/uptrace/bun/schema"
     12 )
     13 
     14 type MergeQuery struct {
     15 	baseQuery
     16 	returningQuery
     17 
     18 	using schema.QueryWithArgs
     19 	on    schema.QueryWithArgs
     20 	when  []schema.QueryAppender
     21 }
     22 
     23 var _ Query = (*MergeQuery)(nil)
     24 
     25 func NewMergeQuery(db *DB) *MergeQuery {
     26 	q := &MergeQuery{
     27 		baseQuery: baseQuery{
     28 			db:   db,
     29 			conn: db.DB,
     30 		},
     31 	}
     32 	if !(q.db.dialect.Name() == dialect.MSSQL || q.db.dialect.Name() == dialect.PG) {
     33 		q.err = errors.New("bun: merge not supported for current dialect")
     34 	}
     35 	return q
     36 }
     37 
     38 func (q *MergeQuery) Conn(db IConn) *MergeQuery {
     39 	q.setConn(db)
     40 	return q
     41 }
     42 
     43 func (q *MergeQuery) Model(model interface{}) *MergeQuery {
     44 	q.setModel(model)
     45 	return q
     46 }
     47 
     48 func (q *MergeQuery) Err(err error) *MergeQuery {
     49 	q.setErr(err)
     50 	return q
     51 }
     52 
     53 // Apply calls the fn passing the MergeQuery as an argument.
     54 func (q *MergeQuery) Apply(fn func(*MergeQuery) *MergeQuery) *MergeQuery {
     55 	if fn != nil {
     56 		return fn(q)
     57 	}
     58 	return q
     59 }
     60 
     61 func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery {
     62 	q.addWith(name, query, false)
     63 	return q
     64 }
     65 
     66 func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery {
     67 	q.addWith(name, query, true)
     68 	return q
     69 }
     70 
     71 //------------------------------------------------------------------------------
     72 
     73 func (q *MergeQuery) Table(tables ...string) *MergeQuery {
     74 	for _, table := range tables {
     75 		q.addTable(schema.UnsafeIdent(table))
     76 	}
     77 	return q
     78 }
     79 
     80 func (q *MergeQuery) TableExpr(query string, args ...interface{}) *MergeQuery {
     81 	q.addTable(schema.SafeQuery(query, args))
     82 	return q
     83 }
     84 
     85 func (q *MergeQuery) ModelTableExpr(query string, args ...interface{}) *MergeQuery {
     86 	q.modelTableName = schema.SafeQuery(query, args)
     87 	return q
     88 }
     89 
     90 //------------------------------------------------------------------------------
     91 
     92 // Returning adds a RETURNING clause to the query.
     93 //
     94 // To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
     95 // Only for mssql output, postgres not supported returning in merge query
     96 func (q *MergeQuery) Returning(query string, args ...interface{}) *MergeQuery {
     97 	q.addReturning(schema.SafeQuery(query, args))
     98 	return q
     99 }
    100 
    101 //------------------------------------------------------------------------------
    102 
    103 func (q *MergeQuery) Using(s string, args ...interface{}) *MergeQuery {
    104 	q.using = schema.SafeQuery(s, args)
    105 	return q
    106 }
    107 
    108 func (q *MergeQuery) On(s string, args ...interface{}) *MergeQuery {
    109 	q.on = schema.SafeQuery(s, args)
    110 	return q
    111 }
    112 
    113 // WhenInsert for when insert clause.
    114 func (q *MergeQuery) WhenInsert(expr string, fn func(q *InsertQuery) *InsertQuery) *MergeQuery {
    115 	sq := NewInsertQuery(q.db)
    116 	// apply the model as default into sub query, since appendColumnsValues required
    117 	if q.model != nil {
    118 		sq = sq.Model(q.model)
    119 	}
    120 	sq = sq.Apply(fn)
    121 	q.when = append(q.when, &whenInsert{expr: expr, query: sq})
    122 	return q
    123 }
    124 
    125 // WhenUpdate for when update clause.
    126 func (q *MergeQuery) WhenUpdate(expr string, fn func(q *UpdateQuery) *UpdateQuery) *MergeQuery {
    127 	sq := NewUpdateQuery(q.db)
    128 	// apply the model as default into sub query
    129 	if q.model != nil {
    130 		sq = sq.Model(q.model)
    131 	}
    132 	sq = sq.Apply(fn)
    133 	q.when = append(q.when, &whenUpdate{expr: expr, query: sq})
    134 	return q
    135 }
    136 
    137 // WhenDelete for when delete clause.
    138 func (q *MergeQuery) WhenDelete(expr string) *MergeQuery {
    139 	q.when = append(q.when, &whenDelete{expr: expr})
    140 	return q
    141 }
    142 
    143 // When for raw expression clause.
    144 func (q *MergeQuery) When(expr string, args ...interface{}) *MergeQuery {
    145 	q.when = append(q.when, schema.SafeQuery(expr, args))
    146 	return q
    147 }
    148 
    149 //------------------------------------------------------------------------------
    150 
    151 func (q *MergeQuery) Operation() string {
    152 	return "MERGE"
    153 }
    154 
    155 func (q *MergeQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    156 	if q.err != nil {
    157 		return nil, q.err
    158 	}
    159 
    160 	fmter = formatterWithModel(fmter, q)
    161 
    162 	b, err = q.appendWith(fmter, b)
    163 	if err != nil {
    164 		return nil, err
    165 	}
    166 
    167 	b = append(b, "MERGE "...)
    168 	if q.db.dialect.Name() == dialect.PG {
    169 		b = append(b, "INTO "...)
    170 	}
    171 
    172 	b, err = q.appendFirstTableWithAlias(fmter, b)
    173 	if err != nil {
    174 		return nil, err
    175 	}
    176 
    177 	b = append(b, " USING "...)
    178 	b, err = q.using.AppendQuery(fmter, b)
    179 	if err != nil {
    180 		return nil, err
    181 	}
    182 
    183 	b = append(b, " ON "...)
    184 	b, err = q.on.AppendQuery(fmter, b)
    185 	if err != nil {
    186 		return nil, err
    187 	}
    188 
    189 	for _, w := range q.when {
    190 		b = append(b, " WHEN "...)
    191 		b, err = w.AppendQuery(fmter, b)
    192 		if err != nil {
    193 			return nil, err
    194 		}
    195 	}
    196 
    197 	if q.hasFeature(feature.Output) && q.hasReturning() {
    198 		b = append(b, " OUTPUT "...)
    199 		b, err = q.appendOutput(fmter, b)
    200 		if err != nil {
    201 			return nil, err
    202 		}
    203 	}
    204 
    205 	// A MERGE statement must be terminated by a semi-colon (;).
    206 	b = append(b, ";"...)
    207 
    208 	return b, nil
    209 }
    210 
    211 //------------------------------------------------------------------------------
    212 
    213 func (q *MergeQuery) Scan(ctx context.Context, dest ...interface{}) error {
    214 	_, err := q.scanOrExec(ctx, dest, true)
    215 	return err
    216 }
    217 
    218 func (q *MergeQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
    219 	return q.scanOrExec(ctx, dest, len(dest) > 0)
    220 }
    221 
    222 func (q *MergeQuery) scanOrExec(
    223 	ctx context.Context, dest []interface{}, hasDest bool,
    224 ) (sql.Result, error) {
    225 	if q.err != nil {
    226 		return nil, q.err
    227 	}
    228 
    229 	// Run append model hooks before generating the query.
    230 	if err := q.beforeAppendModel(ctx, q); err != nil {
    231 		return nil, err
    232 	}
    233 
    234 	// Generate the query before checking hasReturning.
    235 	queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
    236 	if err != nil {
    237 		return nil, err
    238 	}
    239 
    240 	useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.InsertReturning|feature.Output))
    241 	var model Model
    242 
    243 	if useScan {
    244 		var err error
    245 		model, err = q.getModel(dest)
    246 		if err != nil {
    247 			return nil, err
    248 		}
    249 	}
    250 
    251 	query := internal.String(queryBytes)
    252 	var res sql.Result
    253 
    254 	if useScan {
    255 		res, err = q.scan(ctx, q, query, model, true)
    256 		if err != nil {
    257 			return nil, err
    258 		}
    259 	} else {
    260 		res, err = q.exec(ctx, q, query)
    261 		if err != nil {
    262 			return nil, err
    263 		}
    264 	}
    265 
    266 	return res, nil
    267 }
    268 
    269 func (q *MergeQuery) String() string {
    270 	buf, err := q.AppendQuery(q.db.Formatter(), nil)
    271 	if err != nil {
    272 		panic(err)
    273 	}
    274 
    275 	return string(buf)
    276 }
    277 
    278 //------------------------------------------------------------------------------
    279 
    280 type whenInsert struct {
    281 	expr  string
    282 	query *InsertQuery
    283 }
    284 
    285 func (w *whenInsert) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    286 	b = append(b, w.expr...)
    287 	if w.query != nil {
    288 		b = append(b, " THEN INSERT"...)
    289 		b, err = w.query.appendColumnsValues(fmter, b, true)
    290 		if err != nil {
    291 			return nil, err
    292 		}
    293 	}
    294 	return b, nil
    295 }
    296 
    297 type whenUpdate struct {
    298 	expr  string
    299 	query *UpdateQuery
    300 }
    301 
    302 func (w *whenUpdate) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    303 	b = append(b, w.expr...)
    304 	if w.query != nil {
    305 		b = append(b, " THEN UPDATE SET "...)
    306 		b, err = w.query.appendSet(fmter, b)
    307 		if err != nil {
    308 			return nil, err
    309 		}
    310 	}
    311 	return b, nil
    312 }
    313 
    314 type whenDelete struct {
    315 	expr string
    316 }
    317 
    318 func (w *whenDelete) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    319 	b = append(b, w.expr...)
    320 	b = append(b, " THEN DELETE"...)
    321 	return b, nil
    322 }