query_merge.go (7316B)
1 package bun 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 8 "github.com/uptrace/bun/dialect" 9 "github.com/uptrace/bun/dialect/feature" 10 "github.com/uptrace/bun/internal" 11 "github.com/uptrace/bun/schema" 12 ) 13 14 type MergeQuery struct { 15 baseQuery 16 returningQuery 17 18 using schema.QueryWithArgs 19 on schema.QueryWithArgs 20 when []schema.QueryAppender 21 } 22 23 var _ Query = (*MergeQuery)(nil) 24 25 func NewMergeQuery(db *DB) *MergeQuery { 26 q := &MergeQuery{ 27 baseQuery: baseQuery{ 28 db: db, 29 conn: db.DB, 30 }, 31 } 32 if !(q.db.dialect.Name() == dialect.MSSQL || q.db.dialect.Name() == dialect.PG) { 33 q.err = errors.New("bun: merge not supported for current dialect") 34 } 35 return q 36 } 37 38 func (q *MergeQuery) Conn(db IConn) *MergeQuery { 39 q.setConn(db) 40 return q 41 } 42 43 func (q *MergeQuery) Model(model interface{}) *MergeQuery { 44 q.setModel(model) 45 return q 46 } 47 48 func (q *MergeQuery) Err(err error) *MergeQuery { 49 q.setErr(err) 50 return q 51 } 52 53 // Apply calls the fn passing the MergeQuery as an argument. 54 func (q *MergeQuery) Apply(fn func(*MergeQuery) *MergeQuery) *MergeQuery { 55 if fn != nil { 56 return fn(q) 57 } 58 return q 59 } 60 61 func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery { 62 q.addWith(name, query, false) 63 return q 64 } 65 66 func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery { 67 q.addWith(name, query, true) 68 return q 69 } 70 71 //------------------------------------------------------------------------------ 72 73 func (q *MergeQuery) Table(tables ...string) *MergeQuery { 74 for _, table := range tables { 75 q.addTable(schema.UnsafeIdent(table)) 76 } 77 return q 78 } 79 80 func (q *MergeQuery) TableExpr(query string, args ...interface{}) *MergeQuery { 81 q.addTable(schema.SafeQuery(query, args)) 82 return q 83 } 84 85 func (q *MergeQuery) ModelTableExpr(query string, args ...interface{}) *MergeQuery { 86 q.modelTableName = schema.SafeQuery(query, args) 87 return q 88 } 89 90 //------------------------------------------------------------------------------ 91 92 // Returning adds a RETURNING clause to the query. 93 // 94 // To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. 95 // Only for mssql output, postgres not supported returning in merge query 96 func (q *MergeQuery) Returning(query string, args ...interface{}) *MergeQuery { 97 q.addReturning(schema.SafeQuery(query, args)) 98 return q 99 } 100 101 //------------------------------------------------------------------------------ 102 103 func (q *MergeQuery) Using(s string, args ...interface{}) *MergeQuery { 104 q.using = schema.SafeQuery(s, args) 105 return q 106 } 107 108 func (q *MergeQuery) On(s string, args ...interface{}) *MergeQuery { 109 q.on = schema.SafeQuery(s, args) 110 return q 111 } 112 113 // WhenInsert for when insert clause. 114 func (q *MergeQuery) WhenInsert(expr string, fn func(q *InsertQuery) *InsertQuery) *MergeQuery { 115 sq := NewInsertQuery(q.db) 116 // apply the model as default into sub query, since appendColumnsValues required 117 if q.model != nil { 118 sq = sq.Model(q.model) 119 } 120 sq = sq.Apply(fn) 121 q.when = append(q.when, &whenInsert{expr: expr, query: sq}) 122 return q 123 } 124 125 // WhenUpdate for when update clause. 126 func (q *MergeQuery) WhenUpdate(expr string, fn func(q *UpdateQuery) *UpdateQuery) *MergeQuery { 127 sq := NewUpdateQuery(q.db) 128 // apply the model as default into sub query 129 if q.model != nil { 130 sq = sq.Model(q.model) 131 } 132 sq = sq.Apply(fn) 133 q.when = append(q.when, &whenUpdate{expr: expr, query: sq}) 134 return q 135 } 136 137 // WhenDelete for when delete clause. 138 func (q *MergeQuery) WhenDelete(expr string) *MergeQuery { 139 q.when = append(q.when, &whenDelete{expr: expr}) 140 return q 141 } 142 143 // When for raw expression clause. 144 func (q *MergeQuery) When(expr string, args ...interface{}) *MergeQuery { 145 q.when = append(q.when, schema.SafeQuery(expr, args)) 146 return q 147 } 148 149 //------------------------------------------------------------------------------ 150 151 func (q *MergeQuery) Operation() string { 152 return "MERGE" 153 } 154 155 func (q *MergeQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 156 if q.err != nil { 157 return nil, q.err 158 } 159 160 fmter = formatterWithModel(fmter, q) 161 162 b, err = q.appendWith(fmter, b) 163 if err != nil { 164 return nil, err 165 } 166 167 b = append(b, "MERGE "...) 168 if q.db.dialect.Name() == dialect.PG { 169 b = append(b, "INTO "...) 170 } 171 172 b, err = q.appendFirstTableWithAlias(fmter, b) 173 if err != nil { 174 return nil, err 175 } 176 177 b = append(b, " USING "...) 178 b, err = q.using.AppendQuery(fmter, b) 179 if err != nil { 180 return nil, err 181 } 182 183 b = append(b, " ON "...) 184 b, err = q.on.AppendQuery(fmter, b) 185 if err != nil { 186 return nil, err 187 } 188 189 for _, w := range q.when { 190 b = append(b, " WHEN "...) 191 b, err = w.AppendQuery(fmter, b) 192 if err != nil { 193 return nil, err 194 } 195 } 196 197 if q.hasFeature(feature.Output) && q.hasReturning() { 198 b = append(b, " OUTPUT "...) 199 b, err = q.appendOutput(fmter, b) 200 if err != nil { 201 return nil, err 202 } 203 } 204 205 // A MERGE statement must be terminated by a semi-colon (;). 206 b = append(b, ";"...) 207 208 return b, nil 209 } 210 211 //------------------------------------------------------------------------------ 212 213 func (q *MergeQuery) Scan(ctx context.Context, dest ...interface{}) error { 214 _, err := q.scanOrExec(ctx, dest, true) 215 return err 216 } 217 218 func (q *MergeQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { 219 return q.scanOrExec(ctx, dest, len(dest) > 0) 220 } 221 222 func (q *MergeQuery) scanOrExec( 223 ctx context.Context, dest []interface{}, hasDest bool, 224 ) (sql.Result, error) { 225 if q.err != nil { 226 return nil, q.err 227 } 228 229 // Run append model hooks before generating the query. 230 if err := q.beforeAppendModel(ctx, q); err != nil { 231 return nil, err 232 } 233 234 // Generate the query before checking hasReturning. 235 queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) 236 if err != nil { 237 return nil, err 238 } 239 240 useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.InsertReturning|feature.Output)) 241 var model Model 242 243 if useScan { 244 var err error 245 model, err = q.getModel(dest) 246 if err != nil { 247 return nil, err 248 } 249 } 250 251 query := internal.String(queryBytes) 252 var res sql.Result 253 254 if useScan { 255 res, err = q.scan(ctx, q, query, model, true) 256 if err != nil { 257 return nil, err 258 } 259 } else { 260 res, err = q.exec(ctx, q, query) 261 if err != nil { 262 return nil, err 263 } 264 } 265 266 return res, nil 267 } 268 269 func (q *MergeQuery) String() string { 270 buf, err := q.AppendQuery(q.db.Formatter(), nil) 271 if err != nil { 272 panic(err) 273 } 274 275 return string(buf) 276 } 277 278 //------------------------------------------------------------------------------ 279 280 type whenInsert struct { 281 expr string 282 query *InsertQuery 283 } 284 285 func (w *whenInsert) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 286 b = append(b, w.expr...) 287 if w.query != nil { 288 b = append(b, " THEN INSERT"...) 289 b, err = w.query.appendColumnsValues(fmter, b, true) 290 if err != nil { 291 return nil, err 292 } 293 } 294 return b, nil 295 } 296 297 type whenUpdate struct { 298 expr string 299 query *UpdateQuery 300 } 301 302 func (w *whenUpdate) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 303 b = append(b, w.expr...) 304 if w.query != nil { 305 b = append(b, " THEN UPDATE SET "...) 306 b, err = w.query.appendSet(fmter, b) 307 if err != nil { 308 return nil, err 309 } 310 } 311 return b, nil 312 } 313 314 type whenDelete struct { 315 expr string 316 } 317 318 func (w *whenDelete) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 319 b = append(b, w.expr...) 320 b = append(b, " THEN DELETE"...) 321 return b, nil 322 }