relation_join.go (9256B)
1 package bun 2 3 import ( 4 "context" 5 "reflect" 6 "time" 7 8 "github.com/uptrace/bun/dialect/feature" 9 "github.com/uptrace/bun/internal" 10 "github.com/uptrace/bun/schema" 11 ) 12 13 type relationJoin struct { 14 Parent *relationJoin 15 BaseModel TableModel 16 JoinModel TableModel 17 Relation *schema.Relation 18 19 apply func(*SelectQuery) *SelectQuery 20 columns []schema.QueryWithArgs 21 } 22 23 func (j *relationJoin) applyTo(q *SelectQuery) { 24 if j.apply == nil { 25 return 26 } 27 28 var table *schema.Table 29 var columns []schema.QueryWithArgs 30 31 // Save state. 32 table, q.table = q.table, j.JoinModel.Table() 33 columns, q.columns = q.columns, nil 34 35 q = j.apply(q) 36 37 // Restore state. 38 q.table = table 39 j.columns, q.columns = q.columns, columns 40 } 41 42 func (j *relationJoin) Select(ctx context.Context, q *SelectQuery) error { 43 switch j.Relation.Type { 44 } 45 panic("not reached") 46 } 47 48 func (j *relationJoin) selectMany(ctx context.Context, q *SelectQuery) error { 49 q = j.manyQuery(q) 50 if q == nil { 51 return nil 52 } 53 return q.Scan(ctx) 54 } 55 56 func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { 57 hasManyModel := newHasManyModel(j) 58 if hasManyModel == nil { 59 return nil 60 } 61 62 q = q.Model(hasManyModel) 63 64 var where []byte 65 66 if q.db.dialect.Features().Has(feature.CompositeIn) { 67 return j.manyQueryCompositeIn(where, q) 68 } 69 return j.manyQueryMulti(where, q) 70 } 71 72 func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery { 73 if len(j.Relation.JoinFields) > 1 { 74 where = append(where, '(') 75 } 76 where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields) 77 if len(j.Relation.JoinFields) > 1 { 78 where = append(where, ')') 79 } 80 where = append(where, " IN ("...) 81 where = appendChildValues( 82 q.db.Formatter(), 83 where, 84 j.JoinModel.rootValue(), 85 j.JoinModel.parentIndex(), 86 j.Relation.BaseFields, 87 ) 88 where = append(where, ")"...) 89 q = q.Where(internal.String(where)) 90 91 if j.Relation.PolymorphicField != nil { 92 q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) 93 } 94 95 j.applyTo(q) 96 q = q.Apply(j.hasManyColumns) 97 98 return q 99 } 100 101 func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery { 102 where = appendMultiValues( 103 q.db.Formatter(), 104 where, 105 j.JoinModel.rootValue(), 106 j.JoinModel.parentIndex(), 107 j.Relation.BaseFields, 108 j.Relation.JoinFields, 109 j.JoinModel.Table().SQLAlias, 110 ) 111 112 q = q.Where(internal.String(where)) 113 114 if j.Relation.PolymorphicField != nil { 115 q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) 116 } 117 118 j.applyTo(q) 119 q = q.Apply(j.hasManyColumns) 120 121 return q 122 } 123 124 func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery { 125 b := make([]byte, 0, 32) 126 127 joinTable := j.JoinModel.Table() 128 if len(j.columns) > 0 { 129 for i, col := range j.columns { 130 if i > 0 { 131 b = append(b, ", "...) 132 } 133 134 if col.Args == nil { 135 if field, ok := joinTable.FieldMap[col.Query]; ok { 136 b = append(b, joinTable.SQLAlias...) 137 b = append(b, '.') 138 b = append(b, field.SQLName...) 139 continue 140 } 141 } 142 143 var err error 144 b, err = col.AppendQuery(q.db.fmter, b) 145 if err != nil { 146 q.setErr(err) 147 return q 148 } 149 150 } 151 } else { 152 b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields) 153 } 154 155 q = q.ColumnExpr(internal.String(b)) 156 157 return q 158 } 159 160 func (j *relationJoin) selectM2M(ctx context.Context, q *SelectQuery) error { 161 q = j.m2mQuery(q) 162 if q == nil { 163 return nil 164 } 165 return q.Scan(ctx) 166 } 167 168 func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { 169 fmter := q.db.fmter 170 171 m2mModel := newM2MModel(j) 172 if m2mModel == nil { 173 return nil 174 } 175 q = q.Model(m2mModel) 176 177 index := j.JoinModel.parentIndex() 178 baseTable := j.BaseModel.Table() 179 180 if j.Relation.M2MTable != nil { 181 q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*") 182 } 183 184 //nolint 185 var join []byte 186 join = append(join, "JOIN "...) 187 join = fmter.AppendQuery(join, string(j.Relation.M2MTable.SQLName)) 188 join = append(join, " AS "...) 189 join = append(join, j.Relation.M2MTable.SQLAlias...) 190 join = append(join, " ON ("...) 191 for i, col := range j.Relation.M2MBaseFields { 192 if i > 0 { 193 join = append(join, ", "...) 194 } 195 join = append(join, j.Relation.M2MTable.SQLAlias...) 196 join = append(join, '.') 197 join = append(join, col.SQLName...) 198 } 199 join = append(join, ") IN ("...) 200 join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, baseTable.PKs) 201 join = append(join, ")"...) 202 q = q.Join(internal.String(join)) 203 204 joinTable := j.JoinModel.Table() 205 for i, m2mJoinField := range j.Relation.M2MJoinFields { 206 joinField := j.Relation.JoinFields[i] 207 q = q.Where("?.? = ?.?", 208 joinTable.SQLAlias, joinField.SQLName, 209 j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) 210 } 211 212 j.applyTo(q) 213 q = q.Apply(j.hasManyColumns) 214 215 return q 216 } 217 218 func (j *relationJoin) hasParent() bool { 219 if j.Parent != nil { 220 switch j.Parent.Relation.Type { 221 case schema.HasOneRelation, schema.BelongsToRelation: 222 return true 223 } 224 } 225 return false 226 } 227 228 func (j *relationJoin) appendAlias(fmter schema.Formatter, b []byte) []byte { 229 quote := fmter.IdentQuote() 230 231 b = append(b, quote) 232 b = appendAlias(b, j) 233 b = append(b, quote) 234 return b 235 } 236 237 func (j *relationJoin) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { 238 quote := fmter.IdentQuote() 239 240 b = append(b, quote) 241 b = appendAlias(b, j) 242 b = append(b, "__"...) 243 b = append(b, column...) 244 b = append(b, quote) 245 return b 246 } 247 248 func (j *relationJoin) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { 249 quote := fmter.IdentQuote() 250 251 if j.hasParent() { 252 b = append(b, quote) 253 b = appendAlias(b, j.Parent) 254 b = append(b, quote) 255 return b 256 } 257 return append(b, j.BaseModel.Table().SQLAlias...) 258 } 259 260 func (j *relationJoin) appendSoftDelete(fmter schema.Formatter, b []byte, flags internal.Flag) []byte { 261 b = append(b, '.') 262 263 field := j.JoinModel.Table().SoftDeleteField 264 b = append(b, field.SQLName...) 265 266 if field.IsPtr || field.NullZero { 267 if flags.Has(deletedFlag) { 268 b = append(b, " IS NOT NULL"...) 269 } else { 270 b = append(b, " IS NULL"...) 271 } 272 } else { 273 if flags.Has(deletedFlag) { 274 b = append(b, " != "...) 275 } else { 276 b = append(b, " = "...) 277 } 278 b = fmter.Dialect().AppendTime(b, time.Time{}) 279 } 280 281 return b 282 } 283 284 func appendAlias(b []byte, j *relationJoin) []byte { 285 if j.hasParent() { 286 b = appendAlias(b, j.Parent) 287 b = append(b, "__"...) 288 } 289 b = append(b, j.Relation.Field.Name...) 290 return b 291 } 292 293 func (j *relationJoin) appendHasOneJoin( 294 fmter schema.Formatter, b []byte, q *SelectQuery, 295 ) (_ []byte, err error) { 296 isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) 297 298 b = append(b, "LEFT JOIN "...) 299 b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) 300 b = append(b, " AS "...) 301 b = j.appendAlias(fmter, b) 302 303 b = append(b, " ON "...) 304 305 b = append(b, '(') 306 for i, baseField := range j.Relation.BaseFields { 307 if i > 0 { 308 b = append(b, " AND "...) 309 } 310 b = j.appendAlias(fmter, b) 311 b = append(b, '.') 312 b = append(b, j.Relation.JoinFields[i].SQLName...) 313 b = append(b, " = "...) 314 b = j.appendBaseAlias(fmter, b) 315 b = append(b, '.') 316 b = append(b, baseField.SQLName...) 317 } 318 b = append(b, ')') 319 320 if isSoftDelete { 321 b = append(b, " AND "...) 322 b = j.appendAlias(fmter, b) 323 b = j.appendSoftDelete(fmter, b, q.flags) 324 } 325 326 return b, nil 327 } 328 329 func appendChildValues( 330 fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field, 331 ) []byte { 332 seen := make(map[string]struct{}) 333 walk(v, index, func(v reflect.Value) { 334 start := len(b) 335 336 if len(fields) > 1 { 337 b = append(b, '(') 338 } 339 for i, f := range fields { 340 if i > 0 { 341 b = append(b, ", "...) 342 } 343 b = f.AppendValue(fmter, b, v) 344 } 345 if len(fields) > 1 { 346 b = append(b, ')') 347 } 348 b = append(b, ", "...) 349 350 if _, ok := seen[string(b[start:])]; ok { 351 b = b[:start] 352 } else { 353 seen[string(b[start:])] = struct{}{} 354 } 355 }) 356 if len(seen) > 0 { 357 b = b[:len(b)-2] // trim ", " 358 } 359 return b 360 } 361 362 // appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID 363 // but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions. 364 func appendMultiValues( 365 fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe, 366 ) []byte { 367 // This is based on a mix of appendChildValues and query_base.appendColumns 368 369 // These should never missmatch in length but nice to know if it does 370 if len(joinFields) != len(baseFields) { 371 panic("not reached") 372 } 373 374 // walk the relations 375 b = append(b, '(') 376 seen := make(map[string]struct{}) 377 walk(v, index, func(v reflect.Value) { 378 start := len(b) 379 for i, f := range baseFields { 380 if i > 0 { 381 b = append(b, " AND "...) 382 } 383 if len(baseFields) > 1 { 384 b = append(b, '(') 385 } 386 // Field name 387 b = append(b, joinTable...) 388 b = append(b, '.') 389 b = append(b, []byte(joinFields[i].SQLName)...) 390 391 // Equals value 392 b = append(b, '=') 393 b = f.AppendValue(fmter, b, v) 394 if len(baseFields) > 1 { 395 b = append(b, ')') 396 } 397 } 398 399 b = append(b, ") OR ("...) 400 401 if _, ok := seen[string(b[start:])]; ok { 402 b = b[:start] 403 } else { 404 seen[string(b[start:])] = struct{}{} 405 } 406 }) 407 if len(seen) > 0 { 408 b = b[:len(b)-6] // trim ") OR (" 409 } 410 b = append(b, ')') 411 return b 412 }