dialect.go (2448B)
1 package pgdialect 2 3 import ( 4 "database/sql" 5 "fmt" 6 "strconv" 7 "strings" 8 9 "github.com/uptrace/bun" 10 "github.com/uptrace/bun/dialect" 11 "github.com/uptrace/bun/dialect/feature" 12 "github.com/uptrace/bun/dialect/sqltype" 13 "github.com/uptrace/bun/schema" 14 ) 15 16 var pgDialect = New() 17 18 func init() { 19 if Version() != bun.Version() { 20 panic(fmt.Errorf("pgdialect and Bun must have the same version: v%s != v%s", 21 Version(), bun.Version())) 22 } 23 } 24 25 type Dialect struct { 26 schema.BaseDialect 27 28 tables *schema.Tables 29 features feature.Feature 30 } 31 32 func New() *Dialect { 33 d := new(Dialect) 34 d.tables = schema.NewTables(d) 35 d.features = feature.CTE | 36 feature.WithValues | 37 feature.Returning | 38 feature.InsertReturning | 39 feature.DefaultPlaceholder | 40 feature.DoubleColonCast | 41 feature.InsertTableAlias | 42 feature.UpdateTableAlias | 43 feature.DeleteTableAlias | 44 feature.TableCascade | 45 feature.TableIdentity | 46 feature.TableTruncate | 47 feature.TableNotExists | 48 feature.InsertOnConflict | 49 feature.SelectExists | 50 feature.GeneratedIdentity | 51 feature.CompositeIn 52 return d 53 } 54 55 func (d *Dialect) Init(*sql.DB) {} 56 57 func (d *Dialect) Name() dialect.Name { 58 return dialect.PG 59 } 60 61 func (d *Dialect) Features() feature.Feature { 62 return d.features 63 } 64 65 func (d *Dialect) Tables() *schema.Tables { 66 return d.tables 67 } 68 69 func (d *Dialect) OnTable(table *schema.Table) { 70 for _, field := range table.FieldMap { 71 d.onField(field) 72 } 73 } 74 75 func (d *Dialect) onField(field *schema.Field) { 76 field.DiscoveredSQLType = fieldSQLType(field) 77 78 if field.AutoIncrement && !field.Identity { 79 switch field.DiscoveredSQLType { 80 case sqltype.SmallInt: 81 field.CreateTableSQLType = pgTypeSmallSerial 82 case sqltype.Integer: 83 field.CreateTableSQLType = pgTypeSerial 84 case sqltype.BigInt: 85 field.CreateTableSQLType = pgTypeBigSerial 86 } 87 } 88 89 if field.Tag.HasOption("array") || strings.HasSuffix(field.UserSQLType, "[]") { 90 field.Append = d.arrayAppender(field.StructField.Type) 91 field.Scan = arrayScanner(field.StructField.Type) 92 } 93 94 if field.DiscoveredSQLType == sqltype.HSTORE { 95 field.Append = d.hstoreAppender(field.StructField.Type) 96 field.Scan = hstoreScanner(field.StructField.Type) 97 } 98 } 99 100 func (d *Dialect) IdentQuote() byte { 101 return '"' 102 } 103 104 func (d *Dialect) AppendUint32(b []byte, n uint32) []byte { 105 return strconv.AppendInt(b, int64(int32(n)), 10) 106 } 107 108 func (d *Dialect) AppendUint64(b []byte, n uint64) []byte { 109 return strconv.AppendInt(b, int64(n), 10) 110 }