gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

relation_join.go (9256B)


      1 package bun
      2 
      3 import (
      4 	"context"
      5 	"reflect"
      6 	"time"
      7 
      8 	"github.com/uptrace/bun/dialect/feature"
      9 	"github.com/uptrace/bun/internal"
     10 	"github.com/uptrace/bun/schema"
     11 )
     12 
     13 type relationJoin struct {
     14 	Parent    *relationJoin
     15 	BaseModel TableModel
     16 	JoinModel TableModel
     17 	Relation  *schema.Relation
     18 
     19 	apply   func(*SelectQuery) *SelectQuery
     20 	columns []schema.QueryWithArgs
     21 }
     22 
     23 func (j *relationJoin) applyTo(q *SelectQuery) {
     24 	if j.apply == nil {
     25 		return
     26 	}
     27 
     28 	var table *schema.Table
     29 	var columns []schema.QueryWithArgs
     30 
     31 	// Save state.
     32 	table, q.table = q.table, j.JoinModel.Table()
     33 	columns, q.columns = q.columns, nil
     34 
     35 	q = j.apply(q)
     36 
     37 	// Restore state.
     38 	q.table = table
     39 	j.columns, q.columns = q.columns, columns
     40 }
     41 
     42 func (j *relationJoin) Select(ctx context.Context, q *SelectQuery) error {
     43 	switch j.Relation.Type {
     44 	}
     45 	panic("not reached")
     46 }
     47 
     48 func (j *relationJoin) selectMany(ctx context.Context, q *SelectQuery) error {
     49 	q = j.manyQuery(q)
     50 	if q == nil {
     51 		return nil
     52 	}
     53 	return q.Scan(ctx)
     54 }
     55 
     56 func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
     57 	hasManyModel := newHasManyModel(j)
     58 	if hasManyModel == nil {
     59 		return nil
     60 	}
     61 
     62 	q = q.Model(hasManyModel)
     63 
     64 	var where []byte
     65 
     66 	if q.db.dialect.Features().Has(feature.CompositeIn) {
     67 		return j.manyQueryCompositeIn(where, q)
     68 	}
     69 	return j.manyQueryMulti(where, q)
     70 }
     71 
     72 func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery {
     73 	if len(j.Relation.JoinFields) > 1 {
     74 		where = append(where, '(')
     75 	}
     76 	where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields)
     77 	if len(j.Relation.JoinFields) > 1 {
     78 		where = append(where, ')')
     79 	}
     80 	where = append(where, " IN ("...)
     81 	where = appendChildValues(
     82 		q.db.Formatter(),
     83 		where,
     84 		j.JoinModel.rootValue(),
     85 		j.JoinModel.parentIndex(),
     86 		j.Relation.BaseFields,
     87 	)
     88 	where = append(where, ")"...)
     89 	q = q.Where(internal.String(where))
     90 
     91 	if j.Relation.PolymorphicField != nil {
     92 		q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
     93 	}
     94 
     95 	j.applyTo(q)
     96 	q = q.Apply(j.hasManyColumns)
     97 
     98 	return q
     99 }
    100 
    101 func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery {
    102 	where = appendMultiValues(
    103 		q.db.Formatter(),
    104 		where,
    105 		j.JoinModel.rootValue(),
    106 		j.JoinModel.parentIndex(),
    107 		j.Relation.BaseFields,
    108 		j.Relation.JoinFields,
    109 		j.JoinModel.Table().SQLAlias,
    110 	)
    111 
    112 	q = q.Where(internal.String(where))
    113 
    114 	if j.Relation.PolymorphicField != nil {
    115 		q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
    116 	}
    117 
    118 	j.applyTo(q)
    119 	q = q.Apply(j.hasManyColumns)
    120 
    121 	return q
    122 }
    123 
    124 func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery {
    125 	b := make([]byte, 0, 32)
    126 
    127 	joinTable := j.JoinModel.Table()
    128 	if len(j.columns) > 0 {
    129 		for i, col := range j.columns {
    130 			if i > 0 {
    131 				b = append(b, ", "...)
    132 			}
    133 
    134 			if col.Args == nil {
    135 				if field, ok := joinTable.FieldMap[col.Query]; ok {
    136 					b = append(b, joinTable.SQLAlias...)
    137 					b = append(b, '.')
    138 					b = append(b, field.SQLName...)
    139 					continue
    140 				}
    141 			}
    142 
    143 			var err error
    144 			b, err = col.AppendQuery(q.db.fmter, b)
    145 			if err != nil {
    146 				q.setErr(err)
    147 				return q
    148 			}
    149 
    150 		}
    151 	} else {
    152 		b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields)
    153 	}
    154 
    155 	q = q.ColumnExpr(internal.String(b))
    156 
    157 	return q
    158 }
    159 
    160 func (j *relationJoin) selectM2M(ctx context.Context, q *SelectQuery) error {
    161 	q = j.m2mQuery(q)
    162 	if q == nil {
    163 		return nil
    164 	}
    165 	return q.Scan(ctx)
    166 }
    167 
    168 func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
    169 	fmter := q.db.fmter
    170 
    171 	m2mModel := newM2MModel(j)
    172 	if m2mModel == nil {
    173 		return nil
    174 	}
    175 	q = q.Model(m2mModel)
    176 
    177 	index := j.JoinModel.parentIndex()
    178 	baseTable := j.BaseModel.Table()
    179 
    180 	if j.Relation.M2MTable != nil {
    181 		q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*")
    182 	}
    183 
    184 	//nolint
    185 	var join []byte
    186 	join = append(join, "JOIN "...)
    187 	join = fmter.AppendQuery(join, string(j.Relation.M2MTable.SQLName))
    188 	join = append(join, " AS "...)
    189 	join = append(join, j.Relation.M2MTable.SQLAlias...)
    190 	join = append(join, " ON ("...)
    191 	for i, col := range j.Relation.M2MBaseFields {
    192 		if i > 0 {
    193 			join = append(join, ", "...)
    194 		}
    195 		join = append(join, j.Relation.M2MTable.SQLAlias...)
    196 		join = append(join, '.')
    197 		join = append(join, col.SQLName...)
    198 	}
    199 	join = append(join, ") IN ("...)
    200 	join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, baseTable.PKs)
    201 	join = append(join, ")"...)
    202 	q = q.Join(internal.String(join))
    203 
    204 	joinTable := j.JoinModel.Table()
    205 	for i, m2mJoinField := range j.Relation.M2MJoinFields {
    206 		joinField := j.Relation.JoinFields[i]
    207 		q = q.Where("?.? = ?.?",
    208 			joinTable.SQLAlias, joinField.SQLName,
    209 			j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName)
    210 	}
    211 
    212 	j.applyTo(q)
    213 	q = q.Apply(j.hasManyColumns)
    214 
    215 	return q
    216 }
    217 
    218 func (j *relationJoin) hasParent() bool {
    219 	if j.Parent != nil {
    220 		switch j.Parent.Relation.Type {
    221 		case schema.HasOneRelation, schema.BelongsToRelation:
    222 			return true
    223 		}
    224 	}
    225 	return false
    226 }
    227 
    228 func (j *relationJoin) appendAlias(fmter schema.Formatter, b []byte) []byte {
    229 	quote := fmter.IdentQuote()
    230 
    231 	b = append(b, quote)
    232 	b = appendAlias(b, j)
    233 	b = append(b, quote)
    234 	return b
    235 }
    236 
    237 func (j *relationJoin) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte {
    238 	quote := fmter.IdentQuote()
    239 
    240 	b = append(b, quote)
    241 	b = appendAlias(b, j)
    242 	b = append(b, "__"...)
    243 	b = append(b, column...)
    244 	b = append(b, quote)
    245 	return b
    246 }
    247 
    248 func (j *relationJoin) appendBaseAlias(fmter schema.Formatter, b []byte) []byte {
    249 	quote := fmter.IdentQuote()
    250 
    251 	if j.hasParent() {
    252 		b = append(b, quote)
    253 		b = appendAlias(b, j.Parent)
    254 		b = append(b, quote)
    255 		return b
    256 	}
    257 	return append(b, j.BaseModel.Table().SQLAlias...)
    258 }
    259 
    260 func (j *relationJoin) appendSoftDelete(fmter schema.Formatter, b []byte, flags internal.Flag) []byte {
    261 	b = append(b, '.')
    262 
    263 	field := j.JoinModel.Table().SoftDeleteField
    264 	b = append(b, field.SQLName...)
    265 
    266 	if field.IsPtr || field.NullZero {
    267 		if flags.Has(deletedFlag) {
    268 			b = append(b, " IS NOT NULL"...)
    269 		} else {
    270 			b = append(b, " IS NULL"...)
    271 		}
    272 	} else {
    273 		if flags.Has(deletedFlag) {
    274 			b = append(b, " != "...)
    275 		} else {
    276 			b = append(b, " = "...)
    277 		}
    278 		b = fmter.Dialect().AppendTime(b, time.Time{})
    279 	}
    280 
    281 	return b
    282 }
    283 
    284 func appendAlias(b []byte, j *relationJoin) []byte {
    285 	if j.hasParent() {
    286 		b = appendAlias(b, j.Parent)
    287 		b = append(b, "__"...)
    288 	}
    289 	b = append(b, j.Relation.Field.Name...)
    290 	return b
    291 }
    292 
    293 func (j *relationJoin) appendHasOneJoin(
    294 	fmter schema.Formatter, b []byte, q *SelectQuery,
    295 ) (_ []byte, err error) {
    296 	isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag)
    297 
    298 	b = append(b, "LEFT JOIN "...)
    299 	b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects))
    300 	b = append(b, " AS "...)
    301 	b = j.appendAlias(fmter, b)
    302 
    303 	b = append(b, " ON "...)
    304 
    305 	b = append(b, '(')
    306 	for i, baseField := range j.Relation.BaseFields {
    307 		if i > 0 {
    308 			b = append(b, " AND "...)
    309 		}
    310 		b = j.appendAlias(fmter, b)
    311 		b = append(b, '.')
    312 		b = append(b, j.Relation.JoinFields[i].SQLName...)
    313 		b = append(b, " = "...)
    314 		b = j.appendBaseAlias(fmter, b)
    315 		b = append(b, '.')
    316 		b = append(b, baseField.SQLName...)
    317 	}
    318 	b = append(b, ')')
    319 
    320 	if isSoftDelete {
    321 		b = append(b, " AND "...)
    322 		b = j.appendAlias(fmter, b)
    323 		b = j.appendSoftDelete(fmter, b, q.flags)
    324 	}
    325 
    326 	return b, nil
    327 }
    328 
    329 func appendChildValues(
    330 	fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field,
    331 ) []byte {
    332 	seen := make(map[string]struct{})
    333 	walk(v, index, func(v reflect.Value) {
    334 		start := len(b)
    335 
    336 		if len(fields) > 1 {
    337 			b = append(b, '(')
    338 		}
    339 		for i, f := range fields {
    340 			if i > 0 {
    341 				b = append(b, ", "...)
    342 			}
    343 			b = f.AppendValue(fmter, b, v)
    344 		}
    345 		if len(fields) > 1 {
    346 			b = append(b, ')')
    347 		}
    348 		b = append(b, ", "...)
    349 
    350 		if _, ok := seen[string(b[start:])]; ok {
    351 			b = b[:start]
    352 		} else {
    353 			seen[string(b[start:])] = struct{}{}
    354 		}
    355 	})
    356 	if len(seen) > 0 {
    357 		b = b[:len(b)-2] // trim ", "
    358 	}
    359 	return b
    360 }
    361 
    362 // appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID
    363 // but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions.
    364 func appendMultiValues(
    365 	fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe,
    366 ) []byte {
    367 	// This is based on a mix of appendChildValues and query_base.appendColumns
    368 
    369 	// These should never missmatch in length but nice to know if it does
    370 	if len(joinFields) != len(baseFields) {
    371 		panic("not reached")
    372 	}
    373 
    374 	// walk the relations
    375 	b = append(b, '(')
    376 	seen := make(map[string]struct{})
    377 	walk(v, index, func(v reflect.Value) {
    378 		start := len(b)
    379 		for i, f := range baseFields {
    380 			if i > 0 {
    381 				b = append(b, " AND "...)
    382 			}
    383 			if len(baseFields) > 1 {
    384 				b = append(b, '(')
    385 			}
    386 			// Field name
    387 			b = append(b, joinTable...)
    388 			b = append(b, '.')
    389 			b = append(b, []byte(joinFields[i].SQLName)...)
    390 
    391 			// Equals value
    392 			b = append(b, '=')
    393 			b = f.AppendValue(fmter, b, v)
    394 			if len(baseFields) > 1 {
    395 				b = append(b, ')')
    396 			}
    397 		}
    398 
    399 		b = append(b, ") OR ("...)
    400 
    401 		if _, ok := seen[string(b[start:])]; ok {
    402 			b = b[:start]
    403 		} else {
    404 			seen[string(b[start:])] = struct{}{}
    405 		}
    406 	})
    407 	if len(seen) > 0 {
    408 		b = b[:len(b)-6] // trim ") OR ("
    409 	}
    410 	b = append(b, ')')
    411 	return b
    412 }