gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

conn.go (3277B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
import (
	"context"
	"database/sql"
	"errors"

	"github.com/superseriousbusiness/gotosocial/internal/db"
	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect"
)
     28 
     29 // DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
     30 type DBConn struct {
     31 	errProc func(error) db.Error // errProc is the SQL-type specific error processor
     32 	*bun.DB                      // DB is the underlying bun.DB connection
     33 }
     34 
     35 // WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect.
     36 func WrapDBConn(dbConn *bun.DB) *DBConn {
     37 	var errProc func(error) db.Error
     38 	switch dbConn.Dialect().Name() {
     39 	case dialect.PG:
     40 		errProc = processPostgresError
     41 	case dialect.SQLite:
     42 		errProc = processSQLiteError
     43 	default:
     44 		panic("unknown dialect name: " + dbConn.Dialect().Name().String())
     45 	}
     46 	return &DBConn{
     47 		errProc: errProc,
     48 		DB:      dbConn,
     49 	}
     50 }
     51 
     52 // RunInTx wraps execution of the supplied transaction function.
     53 func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error {
     54 	return conn.ProcessError(func() error {
     55 		// Acquire a new transaction
     56 		tx, err := conn.BeginTx(ctx, nil)
     57 		if err != nil {
     58 			return err
     59 		}
     60 
     61 		var done bool
     62 
     63 		defer func() {
     64 			if !done {
     65 				_ = tx.Rollback()
     66 			}
     67 		}()
     68 
     69 		// Perform supplied transaction
     70 		if err := fn(tx); err != nil {
     71 			return err
     72 		}
     73 
     74 		// Finally, commit
     75 		err = tx.Commit() //nolint:contextcheck
     76 		done = true
     77 		return err
     78 	}())
     79 }
     80 
     81 // ProcessError processes an error to replace any known values with our own db.Error types,
     82 // making it easier to catch specific situations (e.g. no rows, already exists, etc)
     83 func (conn *DBConn) ProcessError(err error) db.Error {
     84 	switch {
     85 	case err == nil:
     86 		return nil
     87 	case err == sql.ErrNoRows:
     88 		return db.ErrNoEntries
     89 	default:
     90 		return conn.errProc(err)
     91 	}
     92 }
     93 
     94 // Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
     95 func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
     96 	exists, err := query.Exists(ctx)
     97 
     98 	// Process error as our own and check if it exists
     99 	switch err := conn.ProcessError(err); err {
    100 	case nil:
    101 		return exists, nil
    102 	case db.ErrNoEntries:
    103 		return false, nil
    104 	default:
    105 		return false, err
    106 	}
    107 }
    108 
    109 // NotExists is the functional opposite of conn.Exists()
    110 func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
    111 	exists, err := conn.Exists(ctx, query)
    112 	return !exists, err
    113 }