conn.go (3277B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package bundb 19 20 import ( 21 "context" 22 "database/sql" 23 24 "github.com/superseriousbusiness/gotosocial/internal/db" 25 "github.com/uptrace/bun" 26 "github.com/uptrace/bun/dialect" 27 ) 28 29 // DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality 30 type DBConn struct { 31 errProc func(error) db.Error // errProc is the SQL-type specific error processor 32 *bun.DB // DB is the underlying bun.DB connection 33 } 34 35 // WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect. 36 func WrapDBConn(dbConn *bun.DB) *DBConn { 37 var errProc func(error) db.Error 38 switch dbConn.Dialect().Name() { 39 case dialect.PG: 40 errProc = processPostgresError 41 case dialect.SQLite: 42 errProc = processSQLiteError 43 default: 44 panic("unknown dialect name: " + dbConn.Dialect().Name().String()) 45 } 46 return &DBConn{ 47 errProc: errProc, 48 DB: dbConn, 49 } 50 } 51 52 // RunInTx wraps execution of the supplied transaction function. 53 func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error { 54 return conn.ProcessError(func() error { 55 // Acquire a new transaction 56 tx, err := conn.BeginTx(ctx, nil) 57 if err != nil { 58 return err 59 } 60 61 var done bool 62 63 defer func() { 64 if !done { 65 _ = tx.Rollback() 66 } 67 }() 68 69 // Perform supplied transaction 70 if err := fn(tx); err != nil { 71 return err 72 } 73 74 // Finally, commit 75 err = tx.Commit() //nolint:contextcheck 76 done = true 77 return err 78 }()) 79 } 80 81 // ProcessError processes an error to replace any known values with our own db.Error types, 82 // making it easier to catch specific situations (e.g. no rows, already exists, etc) 83 func (conn *DBConn) ProcessError(err error) db.Error { 84 switch { 85 case err == nil: 86 return nil 87 case err == sql.ErrNoRows: 88 return db.ErrNoEntries 89 default: 90 return conn.errProc(err) 91 } 92 } 93 94 // Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors 95 func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { 96 exists, err := query.Exists(ctx) 97 98 // Process error as our own and check if it exists 99 switch err := conn.ProcessError(err); err { 100 case nil: 101 return exists, nil 102 case db.ErrNoEntries: 103 return false, nil 104 default: 105 return false, err 106 } 107 } 108 109 // NotExists is the functional opposite of conn.Exists() 110 func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { 111 exists, err := conn.Exists(ctx, query) 112 return !exists, err 113 }