gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

bundb.go (14721B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
     20 import (
     21 	"context"
     22 	"crypto/tls"
     23 	"crypto/x509"
     24 	"database/sql"
     25 	"encoding/pem"
     26 	"errors"
     27 	"fmt"
     28 	"os"
     29 	"runtime"
     30 	"strconv"
     31 	"strings"
     32 	"time"
     33 
     34 	"codeberg.org/gruf/go-bytesize"
     35 	"github.com/google/uuid"
     36 	"github.com/jackc/pgx/v5"
     37 	"github.com/jackc/pgx/v5/stdlib"
     38 	"github.com/superseriousbusiness/gotosocial/internal/config"
     39 	"github.com/superseriousbusiness/gotosocial/internal/db"
     40 	"github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations"
     41 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     42 	"github.com/superseriousbusiness/gotosocial/internal/id"
     43 	"github.com/superseriousbusiness/gotosocial/internal/log"
     44 	"github.com/superseriousbusiness/gotosocial/internal/state"
     45 	"github.com/superseriousbusiness/gotosocial/internal/tracing"
     46 	"github.com/uptrace/bun"
     47 	"github.com/uptrace/bun/dialect/pgdialect"
     48 	"github.com/uptrace/bun/dialect/sqlitedialect"
     49 	"github.com/uptrace/bun/migrate"
     50 
     51 	"modernc.org/sqlite"
     52 )
     53 
     54 var registerTables = []interface{}{
     55 	&gtsmodel.AccountToEmoji{},
     56 	&gtsmodel.StatusToEmoji{},
     57 	&gtsmodel.StatusToTag{},
     58 }
     59 
// DBService satisfies the DB interface.
//
// It is a composition of per-domain database implementations (accounts,
// statuses, media, ...), each embedded as its db.* interface so that their
// methods are promoted onto DBService. NewBunDBService constructs every one
// of them around the same *DBConn, which is also kept here in conn.
type DBService struct {
	db.Account
	db.Admin
	db.Basic
	db.Domain
	db.Emoji
	db.Instance
	db.List
	db.Media
	db.Mention
	db.Notification
	db.Relationship
	db.Report
	db.Search
	db.Session
	db.Status
	db.StatusBookmark
	db.StatusFave
	db.Timeline
	db.User
	db.Tombstone
	// conn is the connection shared by all embedded implementations;
	// exposed via GetConn.
	conn *DBConn
}
     84 
     85 // GetConn returns the underlying bun connection.
     86 // Should only be used in testing + exceptional circumstance.
     87 func (dbService *DBService) GetConn() *DBConn {
     88 	return dbService.conn
     89 }
     90 
     91 func doMigration(ctx context.Context, db *bun.DB) error {
     92 	migrator := migrate.NewMigrator(db, migrations.Migrations)
     93 
     94 	if err := migrator.Init(ctx); err != nil {
     95 		return err
     96 	}
     97 
     98 	group, err := migrator.Migrate(ctx)
     99 	if err != nil && !strings.Contains(err.Error(), "no migrations") {
    100 		return err
    101 	}
    102 
    103 	if group == nil || group.ID == 0 {
    104 		log.Info(ctx, "there are no new migrations to run")
    105 		return nil
    106 	}
    107 
    108 	log.Infof(ctx, "MIGRATED DATABASE TO %s", group)
    109 	return nil
    110 }
    111 
    112 // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
    113 // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
    114 func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
    115 	var conn *DBConn
    116 	var err error
    117 	t := strings.ToLower(config.GetDbType())
    118 
    119 	switch t {
    120 	case "postgres":
    121 		conn, err = pgConn(ctx)
    122 		if err != nil {
    123 			return nil, err
    124 		}
    125 	case "sqlite":
    126 		conn, err = sqliteConn(ctx)
    127 		if err != nil {
    128 			return nil, err
    129 		}
    130 	default:
    131 		return nil, fmt.Errorf("database type %s not supported for bundb", t)
    132 	}
    133 
    134 	// Add database query hooks.
    135 	conn.DB.AddQueryHook(queryHook{})
    136 	if config.GetTracingEnabled() {
    137 		conn.DB.AddQueryHook(tracing.InstrumentBun())
    138 	}
    139 
    140 	// execute sqlite pragmas *after* adding database hook;
    141 	// this allows the pragma queries to be logged
    142 	if t == "sqlite" {
    143 		if err := sqlitePragmas(ctx, conn); err != nil {
    144 			return nil, err
    145 		}
    146 	}
    147 
    148 	// table registration is needed for many-to-many, see:
    149 	// https://bun.uptrace.dev/orm/many-to-many-relation/
    150 	for _, t := range registerTables {
    151 		conn.RegisterModel(t)
    152 	}
    153 
    154 	// perform any pending database migrations: this includes
    155 	// the very first 'migration' on startup which just creates
    156 	// necessary tables
    157 	if err := doMigration(ctx, conn.DB); err != nil {
    158 		return nil, fmt.Errorf("db migration error: %s", err)
    159 	}
    160 
    161 	ps := &DBService{
    162 		Account: &accountDB{
    163 			conn:  conn,
    164 			state: state,
    165 		},
    166 		Admin: &adminDB{
    167 			conn:  conn,
    168 			state: state,
    169 		},
    170 		Basic: &basicDB{
    171 			conn: conn,
    172 		},
    173 		Domain: &domainDB{
    174 			conn:  conn,
    175 			state: state,
    176 		},
    177 		Emoji: &emojiDB{
    178 			conn:  conn,
    179 			state: state,
    180 		},
    181 		Instance: &instanceDB{
    182 			conn: conn,
    183 		},
    184 		List: &listDB{
    185 			conn:  conn,
    186 			state: state,
    187 		},
    188 		Media: &mediaDB{
    189 			conn:  conn,
    190 			state: state,
    191 		},
    192 		Mention: &mentionDB{
    193 			conn:  conn,
    194 			state: state,
    195 		},
    196 		Notification: &notificationDB{
    197 			conn:  conn,
    198 			state: state,
    199 		},
    200 		Relationship: &relationshipDB{
    201 			conn:  conn,
    202 			state: state,
    203 		},
    204 		Report: &reportDB{
    205 			conn:  conn,
    206 			state: state,
    207 		},
    208 		Search: &searchDB{
    209 			conn:  conn,
    210 			state: state,
    211 		},
    212 		Session: &sessionDB{
    213 			conn: conn,
    214 		},
    215 		Status: &statusDB{
    216 			conn:  conn,
    217 			state: state,
    218 		},
    219 		StatusBookmark: &statusBookmarkDB{
    220 			conn:  conn,
    221 			state: state,
    222 		},
    223 		StatusFave: &statusFaveDB{
    224 			conn:  conn,
    225 			state: state,
    226 		},
    227 		Timeline: &timelineDB{
    228 			conn:  conn,
    229 			state: state,
    230 		},
    231 		User: &userDB{
    232 			conn:  conn,
    233 			state: state,
    234 		},
    235 		Tombstone: &tombstoneDB{
    236 			conn:  conn,
    237 			state: state,
    238 		},
    239 		conn: conn,
    240 	}
    241 
    242 	// we can confidently return this useable service now
    243 	return ps, nil
    244 }
    245 
    246 func pgConn(ctx context.Context) (*DBConn, error) {
    247 	opts, err := deriveBunDBPGOptions() //nolint:contextcheck
    248 	if err != nil {
    249 		return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
    250 	}
    251 
    252 	sqldb := stdlib.OpenDB(*opts)
    253 
    254 	// Tune db connections for postgres, see:
    255 	// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql
    256 	// - https://www.alexedwards.net/blog/configuring-sqldb
    257 	sqldb.SetMaxOpenConns(maxOpenConns())     // x number of conns per CPU
    258 	sqldb.SetMaxIdleConns(2)                  // assume default 2; if max idle is less than max open, it will be automatically adjusted
    259 	sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections
    260 
    261 	conn := WrapDBConn(bun.NewDB(sqldb, pgdialect.New()))
    262 
    263 	// ping to check the db is there and listening
    264 	if err := conn.PingContext(ctx); err != nil {
    265 		return nil, fmt.Errorf("postgres ping: %s", err)
    266 	}
    267 
    268 	log.Info(ctx, "connected to POSTGRES database")
    269 	return conn, nil
    270 }
    271 
    272 func sqliteConn(ctx context.Context) (*DBConn, error) {
    273 	// validate db address has actually been set
    274 	address := config.GetDbAddress()
    275 	if address == "" {
    276 		return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.DbAddressFlag())
    277 	}
    278 
    279 	// Drop anything fancy from DB address
    280 	address = strings.Split(address, "?")[0]       // drop any provided query strings
    281 	address = strings.TrimPrefix(address, "file:") // we'll prepend this later ourselves
    282 
    283 	// build our own SQLite preferences
    284 	prefs := []string{
    285 		// use immediate transaction lock mode to fail quickly if tx can't lock
    286 		// see https://pkg.go.dev/modernc.org/sqlite#Driver.Open
    287 		"_txlock=immediate",
    288 	}
    289 
    290 	if address == ":memory:" {
    291 		log.Warn(ctx, "using sqlite in-memory mode; all data will be deleted when gts shuts down; this mode should only be used for debugging or running tests")
    292 
    293 		// Use random name for in-memory instead of ':memory:', so
    294 		// multiple in-mem databases can be created without conflict.
    295 		address = uuid.NewString()
    296 
    297 		// in-mem-specific preferences
    298 		prefs = append(prefs, []string{
    299 			"mode=memory",  // indicate in-memory mode using query
    300 			"cache=shared", // shared cache so that tests don't fail
    301 		}...)
    302 	}
    303 
    304 	// rebuild address string with our derived preferences
    305 	address = "file:" + address
    306 	for i, q := range prefs {
    307 		var prefix string
    308 		if i == 0 {
    309 			prefix = "?"
    310 		} else {
    311 			prefix = "&"
    312 		}
    313 		address += prefix + q
    314 	}
    315 
    316 	// Open new DB instance
    317 	sqldb, err := sql.Open("sqlite", address)
    318 	if err != nil {
    319 		if errWithCode, ok := err.(*sqlite.Error); ok {
    320 			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
    321 		}
    322 		return nil, fmt.Errorf("could not open sqlite db with address %s: %w", address, err)
    323 	}
    324 
    325 	// Tune db connections for sqlite, see:
    326 	// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql
    327 	// - https://www.alexedwards.net/blog/configuring-sqldb
    328 	sqldb.SetMaxOpenConns(1)    // only 1 connection regardless of multiplier, see https://github.com/superseriousbusiness/gotosocial/issues/1407
    329 	sqldb.SetMaxIdleConns(1)    // only keep max 1 idle connection around
    330 	sqldb.SetConnMaxLifetime(0) // don't kill connections due to age
    331 
    332 	// Wrap Bun database conn in our own wrapper
    333 	conn := WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()))
    334 
    335 	// ping to check the db is there and listening
    336 	if err := conn.PingContext(ctx); err != nil {
    337 		if errWithCode, ok := err.(*sqlite.Error); ok {
    338 			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
    339 		}
    340 		return nil, fmt.Errorf("sqlite ping: %s", err)
    341 	}
    342 	log.Infof(ctx, "connected to SQLITE database with address %s", address)
    343 
    344 	return conn, nil
    345 }
    346 
    347 /*
    348 	HANDY STUFF
    349 */
    350 
    351 // maxOpenConns returns multiplier * GOMAXPROCS,
    352 // returning just 1 instead if multiplier < 1.
    353 func maxOpenConns() int {
    354 	multiplier := config.GetDbMaxOpenConnsMultiplier()
    355 	if multiplier < 1 {
    356 		return 1
    357 	}
    358 	return multiplier * runtime.GOMAXPROCS(0)
    359 }
    360 
    361 // deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
    362 // with sensible defaults, or an error if it's not satisfied by the provided config.
    363 func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
    364 	// these are all optional, the db adapter figures out defaults
    365 	address := config.GetDbAddress()
    366 
    367 	// validate database
    368 	database := config.GetDbDatabase()
    369 	if database == "" {
    370 		return nil, errors.New("no database set")
    371 	}
    372 
    373 	var tlsConfig *tls.Config
    374 	switch config.GetDbTLSMode() {
    375 	case "", "disable":
    376 		break // nothing to do
    377 	case "enable":
    378 		/* #nosec G402 */
    379 		tlsConfig = &tls.Config{
    380 			InsecureSkipVerify: true,
    381 		}
    382 	case "require":
    383 		tlsConfig = &tls.Config{
    384 			InsecureSkipVerify: false,
    385 			ServerName:         address,
    386 			MinVersion:         tls.VersionTLS12,
    387 		}
    388 	}
    389 
    390 	if certPath := config.GetDbTLSCACert(); tlsConfig != nil && certPath != "" {
    391 		// load the system cert pool first -- we'll append the given CA cert to this
    392 		certPool, err := x509.SystemCertPool()
    393 		if err != nil {
    394 			return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
    395 		}
    396 
    397 		// open the file itself and make sure there's something in it
    398 		caCertBytes, err := os.ReadFile(certPath)
    399 		if err != nil {
    400 			return nil, fmt.Errorf("error opening CA certificate at %s: %s", certPath, err)
    401 		}
    402 		if len(caCertBytes) == 0 {
    403 			return nil, fmt.Errorf("ca cert at %s was empty", certPath)
    404 		}
    405 
    406 		// make sure we have a PEM block
    407 		caPem, _ := pem.Decode(caCertBytes)
    408 		if caPem == nil {
    409 			return nil, fmt.Errorf("could not parse cert at %s into PEM", certPath)
    410 		}
    411 
    412 		// parse the PEM block into the certificate
    413 		caCert, err := x509.ParseCertificate(caPem.Bytes)
    414 		if err != nil {
    415 			return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err)
    416 		}
    417 
    418 		// we're happy, add it to the existing pool and then use this pool in our tls config
    419 		certPool.AddCert(caCert)
    420 		tlsConfig.RootCAs = certPool
    421 	}
    422 
    423 	cfg, _ := pgx.ParseConfig("")
    424 	if address != "" {
    425 		cfg.Host = address
    426 	}
    427 	if port := config.GetDbPort(); port > 0 {
    428 		cfg.Port = uint16(port)
    429 	}
    430 	if u := config.GetDbUser(); u != "" {
    431 		cfg.User = u
    432 	}
    433 	if p := config.GetDbPassword(); p != "" {
    434 		cfg.Password = p
    435 	}
    436 	if tlsConfig != nil {
    437 		cfg.TLSConfig = tlsConfig
    438 	}
    439 	cfg.Database = database
    440 	cfg.RuntimeParams["application_name"] = config.GetApplicationName()
    441 
    442 	return cfg, nil
    443 }
    444 
    445 // sqlitePragmas sets desired sqlite pragmas based on configured values, and
    446 // logs the results of the pragma queries. Errors if something goes wrong.
    447 func sqlitePragmas(ctx context.Context, conn *DBConn) error {
    448 	var pragmas [][]string
    449 	if mode := config.GetDbSqliteJournalMode(); mode != "" {
    450 		// Set the user provided SQLite journal mode
    451 		pragmas = append(pragmas, []string{"journal_mode", mode})
    452 	}
    453 
    454 	if mode := config.GetDbSqliteSynchronous(); mode != "" {
    455 		// Set the user provided SQLite synchronous mode
    456 		pragmas = append(pragmas, []string{"synchronous", mode})
    457 	}
    458 
    459 	if size := config.GetDbSqliteCacheSize(); size > 0 {
    460 		// Set the user provided SQLite cache size (in kibibytes)
    461 		// Prepend a '-' character to this to indicate to sqlite
    462 		// that we're giving kibibytes rather than num pages.
    463 		// https://www.sqlite.org/pragma.html#pragma_cache_size
    464 		s := "-" + strconv.FormatUint(uint64(size/bytesize.KiB), 10)
    465 		pragmas = append(pragmas, []string{"cache_size", s})
    466 	}
    467 
    468 	if timeout := config.GetDbSqliteBusyTimeout(); timeout > 0 {
    469 		t := strconv.FormatInt(timeout.Milliseconds(), 10)
    470 		pragmas = append(pragmas, []string{"busy_timeout", t})
    471 	}
    472 
    473 	for _, p := range pragmas {
    474 		pk := p[0]
    475 		pv := p[1]
    476 
    477 		if _, err := conn.DB.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil {
    478 			return fmt.Errorf("error executing sqlite pragma %s: %w", pk, err)
    479 		}
    480 
    481 		var res string
    482 		if err := conn.DB.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil {
    483 			return fmt.Errorf("error scanning sqlite pragma %s: %w", pv, err)
    484 		}
    485 
    486 		log.Infof(ctx, "sqlite pragma %s set to %s", pk, res)
    487 	}
    488 
    489 	return nil
    490 }
    491 
    492 /*
    493 	CONVERSION FUNCTIONS
    494 */
    495 
    496 func (dbService *DBService) TagStringToTag(ctx context.Context, t string, originAccountID string) (*gtsmodel.Tag, error) {
    497 	protocol := config.GetProtocol()
    498 	host := config.GetHost()
    499 	now := time.Now()
    500 
    501 	tag := &gtsmodel.Tag{}
    502 	// we can use selectorinsert here to create the new tag if it doesn't exist already
    503 	// inserted will be true if this is a new tag we just created
    504 	if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows {
    505 		return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
    506 	}
    507 
    508 	if tag.ID == "" {
    509 		// tag doesn't exist yet so populate it
    510 		newID, err := id.NewRandomULID()
    511 		if err != nil {
    512 			return nil, err
    513 		}
    514 		tag.ID = newID
    515 		tag.URL = protocol + "://" + host + "/tags/" + t
    516 		tag.Name = t
    517 		tag.FirstSeenFromAccountID = originAccountID
    518 		tag.CreatedAt = now
    519 		tag.UpdatedAt = now
    520 		useable := true
    521 		tag.Useable = &useable
    522 		listable := true
    523 		tag.Listable = &listable
    524 	}
    525 
    526 	// bail already if the tag isn't useable
    527 	if !*tag.Useable {
    528 		return nil, fmt.Errorf("tag %s is not useable", t)
    529 	}
    530 	tag.LastStatusAt = now
    531 	return tag, nil
    532 }