commit aa07750bdb4dacdb1be39d765114915bba3fc29f parent e58a6a2da3b808ca21d1ef1c1ed83ad932dd9dd6 Author: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Sat, 8 Oct 2022 13:50:48 +0200 [chore] Standardize database queries, use `bun.Ident()` properly (#886) * use bun.Ident for user queries * use bun.Ident for account queries * use bun.Ident for media queries * add DeleteAccount func * remove CaseInsensitive in Where+use Ident ipv Safe * update admin db * update domain, use ident * update emoji, use ident * update instance queries, use bun.Ident * fix media * update mentions, use bun ident * update relationship + tests * use tableexpr * add test follows to bun db test suite * update notifications * updatebyprimarykey => updatebyid * fix session * prefer explicit ID to pk * fix little fucky wucky * remove workaround * use proper db func for attachment selection * update status db * add m2m entries in test rig * fix up timeline * go fmt * fix status put issue * update GetAccountStatuses Diffstat:
45 files changed, 1032 insertions(+), 528 deletions(-)
diff --git a/cmd/gotosocial/action/admin/account/account.go b/cmd/gotosocial/action/admin/account/account.go @@ -101,7 +101,7 @@ var Confirm action.GTSAction = func(ctx context.Context) error { u.Email = u.UnconfirmedEmail u.ConfirmedAt = time.Now() u.UpdatedAt = time.Now() - if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { + if err := dbConn.UpdateByID(ctx, u, u.ID, updatingColumns...); err != nil { return err } diff --git a/internal/cache/account.go b/internal/cache/account.go @@ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) { c.cache.Set(account.ID, copyAccount(account)) } +// Invalidate removes (invalidates) one account from the cache by its ID. +func (c *AccountCache) Invalidate(id string) { + c.cache.Invalidate(id) +} + // copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. // due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) // this should be a relatively cheap process diff --git a/internal/db/account.go b/internal/db/account.go @@ -48,6 +48,11 @@ type Account interface { // UpdateAccount updates one account by ID. UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + // DeleteAccount deletes one account from the database by its ID. + // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the + // account as suspended instead, rather than deleting from the db entirely. + DeleteAccount(ctx context.Context, id string) Error + // GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username. GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) diff --git a/internal/db/basic.go b/internal/db/basic.go @@ -62,11 +62,11 @@ type Basic interface { // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. Put(ctx context.Context, i interface{}) Error - // UpdateByPrimaryKey updates values of i based on its primary key. + // UpdateByID updates values of i based on its id. // If any columns are specified, these will be updated exclusively. // Otherwise, the whole model will be updated. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) Error + UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go @@ -21,7 +21,6 @@ package bundb import ( "context" "errors" - "fmt" "strings" "time" @@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac return a.cache.GetByID(id) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) }, ) } @@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. return a.cache.GetByURI(uri) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) }, ) } @@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. return a.cache.GetByURL(url) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) }, ) } @@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str q := a.newAccountQ(account) if domain != "" { - q = q.Where("account.username = ?", username) - q = q.Where("account.domain = ?", domain) + q = q.Where("? = ?", bun.Ident("account.username"), username) + q = q.Where("? = ?", bun.Ident("account.domain"), domain) } else { - q = q.Where("account.username = ?", strings.ToLower(username)) - q = q.Where("account.domain IS NULL") + q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) + q = q.Where("? IS NULL", bun.Ident("account.domain")) } return q.Scan(ctx) @@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo return a.cache.GetByPubkeyID(id) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) }, ) } @@ -169,26 +168,36 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this account and any emojis it uses // first clear out any old emoji links - if _, err := tx.NewDelete(). - Model(&[]*gtsmodel.AccountToEmoji{}). - Where("account_id = ?", account.ID). + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). Exec(ctx); err != nil { return err } // now populate new emoji links for _, i := range account.EmojiIDs { - if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { + if _, err := tx. + NewInsert(). + Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { return err } } // update the account - _, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx) - return err + if _, err := tx. + NewUpdate(). + Model(account). + Where("? = ?", bun.Ident("account.id"), account.ID). + Exec(ctx); err != nil { + return err + } + + return nil }); err != nil { return nil, a.conn.ProcessError(err) } @@ -197,6 +206,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account return account, nil } +func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { + if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // clear out any emoji links + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), id). + Exec(ctx); err != nil { + return err + } + + // delete the account + _, err := tx. + NewUpdate(). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Where("? = ?", bun.Ident("account.id"), id). + Exec(ctx) + return err + }); err != nil { + return a.conn.ProcessError(err) + } + + a.cache.Invalidate(id) + return nil +} + func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { account := new(gtsmodel.Account) @@ -204,11 +239,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts if domain != "" { q = q. - Where("account.username = ?", domain). - Where("account.domain = ?", domain) + Where("? = ?", bun.Ident("account.username"), domain). + Where("? = ?", bun.Ident("account.domain"), domain) } else { q = q. - Where("account.username = ?", config.GetHost()). + Where("? = ?", bun.Ident("account.username"), config.GetHost()). WhereGroup(" AND ", whereEmptyOrNull("domain")) } @@ -224,10 +259,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) q := a.conn. NewSelect(). Model(status). - Order("id DESC"). - Limit(1). - Where("account_id = ?", accountID). - Column("created_at") + Column("status.created_at"). + Where("? = ?", bun.Ident("status.account_id"), accountID). + Order("status.id DESC"). + Limit(1) if err := q.Scan(ctx); err != nil { return time.Time{}, a.conn.ProcessError(err) @@ -240,12 +275,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen return errors.New("one media attachment cannot be both header and avatar") } - var headerOrAVI string + var column bun.Ident switch { case *mediaAttachment.Avatar: - headerOrAVI = "avatar" + column = bun.Ident("account.avatar_media_attachment_id") case *mediaAttachment.Header: - headerOrAVI = "header" + column = bun.Ident("account.header_media_attachment_id") default: return errors.New("given media attachment was neither a header nor an avatar") } @@ -257,11 +292,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen Exec(ctx); err != nil { return a.conn.ProcessError(err) } + if _, err := a.conn. NewUpdate(). - Model(>smodel.Account{}). - Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). - Where("id = ?", accountID). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Set("? = ?", column, mediaAttachment.ID). + Where("? = ?", bun.Ident("account.id"), accountID). Exec(ctx); err != nil { return a.conn.ProcessError(err) } @@ -284,7 +320,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g if err := a.conn. NewSelect(). Model(faves). - Where("account_id = ?", accountID). + Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Scan(ctx); err != nil { return nil, a.conn.ProcessError(err) } @@ -295,8 +331,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { return a.conn. NewSelect(). - Model(>smodel.Status{}). - Where("account_id = ?", accountID). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Where("? = ?", bun.Ident("status.account_id"), accountID). Count(ctx) } @@ -305,12 +341,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li q := a.conn. NewSelect(). - Table("statuses"). - Column("id"). - Order("id DESC") + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.id"). + Order("status.id DESC") if accountID != "" { - q = q.Where("account_id = ?", accountID) + q = q.Where("? = ?", bun.Ident("status.account_id"), accountID) } if limit != 0 { @@ -321,27 +357,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li // include self-replies (threads) whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { return q. - WhereOr("in_reply_to_account_id = ?", accountID). - WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri")) + WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). + WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri")) } q = q.WhereGroup(" AND ", whereGroup) } if excludeReblogs { - q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")) + q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")) } if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } if minID != "" { - q = q.Where("id > ?", minID) + q = q.Where("? > ?", bun.Ident("status.id"), minID) } if pinnedOnly { - q = q.Where("pinned = ?", true) + q = q.Where("? = ?", bun.Ident("status.pinned"), true) } if mediaOnly { @@ -352,15 +388,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li switch a.conn.Dialect().Name() { case dialect.PG: return q. - Where("? IS NOT NULL", bun.Ident("attachments")). - Where("? != '{}'", bun.Ident("attachments")) + Where("? IS NOT NULL", bun.Ident("status.attachments")). + Where("? != '{}'", bun.Ident("status.attachments")) case dialect.SQLite: return q. - Where("? IS NOT NULL", bun.Ident("attachments")). - Where("? != ''", bun.Ident("attachments")). - Where("? != 'null'", bun.Ident("attachments")). - Where("? != '{}'", bun.Ident("attachments")). - Where("? != '[]'", bun.Ident("attachments")) + Where("? IS NOT NULL", bun.Ident("status.attachments")). + Where("? != ''", bun.Ident("status.attachments")). + Where("? != 'null'", bun.Ident("status.attachments")). + Where("? != '{}'", bun.Ident("status.attachments")). + Where("? != '[]'", bun.Ident("status.attachments")) default: log.Panic("db dialect was neither pg nor sqlite") return q @@ -369,7 +405,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if publicOnly { - q = q.Where("visibility = ?", gtsmodel.VisibilityPublic) + q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic) } if err := q.Scan(ctx, &statusIDs); err != nil { @@ -384,19 +420,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, q := a.conn. NewSelect(). - Table("statuses"). - Column("id"). - Where("account_id = ?", accountID). - WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")). - WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")). - Where("visibility = ?", gtsmodel.VisibilityPublic). - Where("federated = ?", true) + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.id"). + Where("? = ?", bun.Ident("status.account_id"), accountID). + WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). + WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). + Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). + Where("? = ?", bun.Ident("status.federated"), true) if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } - q = q.Limit(limit).Order("id DESC") + q = q.Limit(limit).Order("status.id DESC") if err := q.Scan(ctx, &statusIDs); err != nil { return nil, a.conn.ProcessError(err) @@ -411,16 +447,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI fq := a.conn. NewSelect(). Model(&blocks). - Where("block.account_id = ?", accountID). + Where("? = ?", bun.Ident("block.account_id"), accountID). Relation("TargetAccount"). Order("block.id DESC") if maxID != "" { - fq = fq.Where("block.id < ?", maxID) + fq = fq.Where("? < ?", bun.Ident("block.id"), maxID) } if sinceID != "" { - fq = fq.Where("block.id > ?", sinceID) + fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID) } if limit > 0 { diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go @@ -42,6 +42,18 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() { suite.Len(statuses, 5) } +func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() { + statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, false) + suite.NoError(err) + suite.Len(statuses, 5) +} + +func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() { + statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, true) + suite.NoError(err) + suite.Len(statuses, 1) +} + func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() { statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, true, false) suite.NoError(err) @@ -99,7 +111,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { err = dbService.GetConn(). NewSelect(). Model(noCache). - Where("account.id = ?", bun.Ident(testAccount.ID)). + Where("? = ?", bun.Ident("account.id"), testAccount.ID). Relation("AvatarMediaAttachment"). Relation("HeaderMediaAttachment"). Relation("Emojis"). @@ -127,7 +139,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { err = dbService.GetConn(). NewSelect(). Model(noCache). - Where("account.id = ?", bun.Ident(testAccount.ID)). + Where("? = ?", bun.Ident("account.id"), testAccount.ID). Relation("AvatarMediaAttachment"). Relation("HeaderMediaAttachment"). Relation("Emojis"). diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go @@ -22,7 +22,6 @@ import ( "context" "crypto/rand" "crypto/rsa" - "database/sql" "fmt" "net" "net/mail" @@ -37,21 +36,26 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/uris" + "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) +// generate RSA keys of this length +const rsaKeyBits = 2048 + type adminDB struct { - conn *DBConn - userCache *cache.UserCache + conn *DBConn + userCache *cache.UserCache + accountCache *cache.AccountCache } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { q := a.conn. NewSelect(). - Model(>smodel.Account{}). - Where("username = ?", username). - Where("domain = ?", nil) - + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.username"), username). + Where("? IS NULL", bun.Ident("account.domain")) return a.conn.NotExists(ctx, q) } @@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ // check if the email domain is blocked - if err := a.conn. + emailDomainBlockedQ := a.conn. NewSelect(). - Model(>smodel.EmailDomainBlock{}). - Where("domain = ?", domain). - Scan(ctx); err == nil { - // fail because we found something + TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). + Column("email_domain_block.id"). + Where("? = ?", bun.Ident("email_domain_block.domain"), domain) + emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ) + if err != nil { + return false, err + } + if emailDomainBlocked { return false, fmt.Errorf("email domain %s is blocked", domain) - } else if err != sql.ErrNoRows { - return false, a.conn.ProcessError(err) } // check if this email is associated with a user already q := a.conn. NewSelect(). - Model(>smodel.User{}). - Where("email = ?", email). - WhereOr("unconfirmed_email = ?", email) - + TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). + Column("user.id"). + Where("? = ?", bun.Ident("user.email"), email). + WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) return a.conn.NotExists(ctx, q) } func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) + key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) if err != nil { log.Errorf("error creating new rsa key: %s", err) return nil, err @@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, // if something went wrong while creating a user, we might already have an account, so check here first... acct := >smodel.Account{} - q := a.conn.NewSelect(). + if err := a.conn. + NewSelect(). Model(acct). - Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")) + Where("? = ?", bun.Ident("account.username"), username). + WhereGroup(" AND ", whereEmptyOrNull("account.domain")). + Scan(ctx); err != nil { + err = a.conn.ProcessError(err) + if err != db.ErrNoEntries { + log.Errorf("error checking for existing account: %s", err) + return nil, err + } - if err := q.Scan(ctx); err != nil { - // we just don't have an account yet so create one before we proceed + // if we have db.ErrNoEntries, we just don't have an + // account yet so create one before we proceed accountURIs := uris.GenerateURIsForAccount(username) accountID, err := id.NewRandomULID() if err != nil { @@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, FeaturedCollectionURI: accountURIs.CollectionURI, } + // insert the new account! if _, err = a.conn. NewInsert(). Model(acct). Exec(ctx); err != nil { return nil, a.conn.ProcessError(err) } + a.accountCache.Put(acct) } + // we either created or already had an account by now, + // so proceed with creating a user for that account + pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("error hashing password: %s", err) @@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, u.Moderator = &moderator } + // insert the user! if _, err = a.conn. NewInsert(). Model(u). @@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { q := a.conn. NewSelect(). - Model(>smodel.Account{}). - Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")) + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.username"), username). + WhereGroup(" AND ", whereEmptyOrNull("account.domain")) exists, err := a.conn.Exists(ctx, q) if err != nil { @@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { return nil } - key, err := rsa.GenerateKey(rand.Reader, 2048) + key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) if err != nil { log.Errorf("error creating new rsa key: %s", err) return err @@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { return a.conn.ProcessError(err) } + a.accountCache.Put(acct) log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } @@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { // check if instance entry already exists q := a.conn. NewSelect(). - Model(>smodel.Instance{}). - Where("domain = ?", host) + Column("instance.id"). + TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). + Where("? = ?", bun.Ident("instance.domain"), host) exists, err := a.conn.Exists(ctx, q) if err != nil { diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/stretchr/testify/suite" + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -30,6 +31,44 @@ type AdminTestSuite struct { BunDBStandardTestSuite } +func (suite *AdminTestSuite) TestIsUsernameAvailableNo() { + available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork") + suite.NoError(err) + suite.False(available) +} + +func (suite *AdminTestSuite) TestIsUsernameAvailableYes() { + available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different") + suite.NoError(err) + suite.True(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableNo() { + available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org") + suite.NoError(err) + suite.False(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableYes() { + available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") + suite.NoError(err) + suite.True(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() { + if err := suite.db.Put(context.Background(), >smodel.EmailDomainBlock{ + ID: "01GEEV2R2YC5GRSN96761YJE47", + Domain: "somewhere.com", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + }); err != nil { + suite.FailNow(err.Error()) + } + + available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") + suite.EqualError(err, "email domain somewhere.com is blocked") + suite.False(available) +} + func (suite *AdminTestSuite) TestCreateInstanceAccount() { // we need to take an empty db for this... testrig.StandardDBTeardown(suite.db) diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go @@ -94,12 +94,12 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface return b.conn.ProcessError(err) } -func (b *basicDB) UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) db.Error { +func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error { q := b.conn. NewUpdate(). Model(i). Column(columns...). - WherePK() + Where("? = ?", bun.Ident("id"), id) _, err := q.Exec(ctx) return b.conn.ProcessError(err) @@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, updateWhere(q, where) - q = q.Set("? = ?", bun.Safe(key), value) + q = q.Set("? = ?", bun.Ident(key), value) _, err := q.Exec(ctx) return b.conn.ProcessError(err) diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go @@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { return nil, fmt.Errorf("db migration error: %s", err) } - // Create DB structs that require ptrs to each other - accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()} - status := &statusDB{conn: conn, cache: cache.NewStatusCache()} - emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} - timeline := &timelineDB{conn: conn} - - // Setup DB cross-referencing - accounts.status = status - status.accounts = accounts - timeline.status = status + // Prepare caches required by more than one struct + userCache := cache.NewUserCache() + accountCache := cache.NewAccountCache() + // Prepare other caches // Prepare mentions cache // TODO: move into internal/cache mentionCache := grufcache.New[string, *gtsmodel.Mention]() @@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { notifCache.SetTTL(time.Minute*5, false) notifCache.Start(time.Second * 10) - // Prepare other caches - blockCache := cache.NewDomainBlockCache() - userCache := cache.NewUserCache() + // Create DB structs that require ptrs to each other + accounts := &accountDB{conn: conn, cache: accountCache} + status := &statusDB{conn: conn, cache: cache.NewStatusCache()} + emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} + timeline := &timelineDB{conn: conn} + + // Setup DB cross-referencing + accounts.status = status + status.accounts = accounts + timeline.status = status ps := &DBService{ Account: accounts, Admin: &adminDB{ - conn: conn, - userCache: userCache, + conn: conn, + userCache: userCache, + accountCache: accountCache, }, Basic: &basicDB{ conn: conn, }, Domain: &domainDB{ conn: conn, - cache: blockCache, + cache: cache.NewDomainBlockCache(), }, Emoji: emoji, Instance: &instanceDB{ diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go @@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct { testStatuses map[string]*gtsmodel.Status testTags map[string]*gtsmodel.Tag testMentions map[string]*gtsmodel.Mention + testFollows map[string]*gtsmodel.Follow } func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { suite.testStatuses = testrig.NewTestStatuses() suite.testTags = testrig.NewTestTags() suite.testMentions = testrig.NewTestMentions() + suite.testFollows = testrig.NewTestFollows() } func (suite *BunDBStandardTestSuite) SetupTest() { diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go @@ -28,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" "golang.org/x/net/idna" ) @@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel q := d.conn. NewSelect(). Model(block). - Where("domain = ?", domain). + Where("? = ?", bun.Ident("domain_block.domain"), domain). Limit(1) // Query database for domain block @@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro // Attempt to delete domain block if _, err := d.conn.NewDelete(). Model((*gtsmodel.DomainBlock)(nil)). - Where("domain = ?", domain). + Where("? = ?", bun.Ident("domain_block.domain"), domain). Exec(ctx); err != nil { return d.conn.ProcessError(err) } diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go @@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er q := e.conn. NewSelect(). - Table("emojis"). - Column("id"). - Where("visible_in_picker = true"). - Where("disabled = false"). - Where("domain IS NULL"). - Order("shortcode ASC") + TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")). + Column("emoji.id"). + Where("? = ?", bun.Ident("emoji.visible_in_picker"), true). + Where("? = ?", bun.Ident("emoji.disabled"), false). + Where("? IS NULL", bun.Ident("emoji.domain")). + Order("emoji.shortcode ASC") if err := q.Scan(ctx, &emojiIDs); err != nil { return nil, e.conn.ProcessError(err) @@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, return e.cache.GetByID(id) }, func(emoji *gtsmodel.Emoji) error { - return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx) + return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) }, ) } @@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj return e.cache.GetByURI(uri) }, func(emoji *gtsmodel.Emoji) error { - return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx) + return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) }, ) } @@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin q := e.newEmojiQ(emoji) if domain != "" { - q = q.Where("emoji.shortcode = ?", shortcode) - q = q.Where("emoji.domain = ?", domain) + q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode) + q = q.Where("? = ?", bun.Ident("emoji.domain"), domain) } else { - q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode)) - q = q.Where("emoji.domain IS NULL") + q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode)) + q = q.Where("? IS NULL", bun.Ident("emoji.domain")) } return q.Scan(ctx) diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go @@ -24,7 +24,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" ) @@ -35,15 +34,16 @@ type instanceDB struct { func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { q := i.conn. NewSelect(). - Model(&[]*gtsmodel.Account{}). - Where("username != ?", domain). - Where("? IS NULL", bun.Ident("suspended_at")) + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? != ?", bun.Ident("account.username"), domain). + Where("? IS NULL", bun.Ident("account.suspended_at")) - if domain == config.GetHost() { + if domain == config.GetHost() || domain == config.GetAccountDomain() { // if the domain is *this* domain, just count where the domain field is null - q = q.WhereGroup(" AND ", whereEmptyOrNull("domain")) + q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain")) } else { - q = q.Where("domain = ?", domain) + q = q.Where("? = ?", bun.Ident("account.domain"), domain) } count, err := q.Count(ctx) @@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { q := i.conn. NewSelect(). - Model(&[]*gtsmodel.Status{}) + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")) - if domain == config.GetHost() { + if domain == config.GetHost() || domain == config.GetAccountDomain() { // if the domain is *this* domain, just count where local is true - q = q.Where("local = ?", true) + q = q.Where("? = ?", bun.Ident("status.local"), true) } else { // join on the domain of the account - q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). - Where("account.domain = ?", domain) + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")). + Where("? = ?", bun.Ident("account.domain"), domain) } count, err := q.Count(ctx) @@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { q := i.conn. NewSelect(). - Model(&[]*gtsmodel.Instance{}) + TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")) if domain == config.GetHost() { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked q = q. - Where("domain != ?", domain). - Where("? IS NULL", bun.Ident("suspended_at")) + Where("? != ?", bun.Ident("instance.domain"), domain). + Where("? IS NULL", bun.Ident("instance.suspended_at")) } else { // TODO: implement federated domain counting properly for remote domains return 0, nil @@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool q := i.conn. NewSelect(). Model(&instances). - Where("domain != ?", config.GetHost()) + Where("? != ?", bun.Ident("instance.domain"), config.GetHost()) if !includeSuspended { - q = q.Where("? IS NULL", bun.Ident("suspended_at")) + q = q.Where("? IS NULL", bun.Ident("instance.suspended_at")) } if err := q.Scan(ctx); err != nil { @@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool } func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { - log.Debug("GetAccountsForInstance") - accounts := []*gtsmodel.Account{} q := i.conn.NewSelect(). Model(&accounts). - Where("domain = ?", domain). - Order("id DESC") + Where("? = ?", bun.Ident("account.domain"), domain). + Order("account.id DESC") if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("account.id"), maxID) } if limit > 0 { diff --git a/internal/db/bundb/instance_test.go b/internal/db/bundb/instance_test.go @@ -0,0 +1,83 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" +) + +type InstanceTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *InstanceTestSuite) TestCountInstanceUsers() { + count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost()) + suite.NoError(err) + suite.Equal(4, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() { + count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io") + suite.NoError(err) + suite.Equal(1, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceStatuses() { + count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost()) + suite.NoError(err) + suite.Equal(16, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() { + count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io") + suite.NoError(err) + suite.Equal(1, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceDomains() { + count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost()) + suite.NoError(err) + suite.Equal(2, count) +} + +func (suite *InstanceTestSuite) TestGetInstancePeers() { + peers, err := suite.db.GetInstancePeers(context.Background(), false) + suite.NoError(err) + suite.Len(peers, 2) +} + +func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() { + peers, err := suite.db.GetInstancePeers(context.Background(), true) + suite.NoError(err) + suite.Len(peers, 2) +} + +func (suite *InstanceTestSuite) TestGetInstanceAccounts() { + accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10) + suite.NoError(err) + suite.Len(accounts, 1) +} + +func TestInstanceTestSuite(t *testing.T) { + suite.Run(t, new(InstanceTestSuite)) +} diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go @@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M attachment := >smodel.MediaAttachment{} q := m.newMediaQ(attachment). - Where("media_attachment.id = ?", id) + Where("? = ?", bun.Ident("media_attachment.id"), id) if err := q.Scan(ctx); err != nil { return nil, m.conn.ProcessError(err) @@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l q := m.conn. NewSelect(). Model(&attachments). - Where("media_attachment.cached = true"). - Where("media_attachment.avatar = false"). - Where("media_attachment.header = false"). - Where("media_attachment.created_at < ?", olderThan). + Where("? = ?", bun.Ident("media_attachment.cached"), true). + Where("? = ?", bun.Ident("media_attachment.avatar"), false). + Where("? = ?", bun.Ident("media_attachment.header"), false). + Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan). WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")). Order("media_attachment.created_at DESC") @@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit q := m.newMediaQ(&attachments). WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery { return innerQ. - WhereOr("media_attachment.avatar = true"). - WhereOr("media_attachment.header = true") + WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true). + WhereOr("? = ?", bun.Ident("media_attachment.header"), true) }). Order("media_attachment.id DESC") if maxID != "" { - q = q.Where("media_attachment.id < ?", maxID) + q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID) } if limit != 0 { @@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim attachments := []*gtsmodel.MediaAttachment{} q := m.newMediaQ(&attachments). - Where("media_attachment.cached = true"). - Where("media_attachment.avatar = false"). - Where("media_attachment.header = false"). - Where("media_attachment.created_at < ?", olderThan). - Where("media_attachment.remote_url IS NULL"). - Where("media_attachment.status_id IS NULL") + Where("? = ?", bun.Ident("media_attachment.cached"), true). + Where("? = ?", bun.Ident("media_attachment.avatar"), false). + Where("? = ?", bun.Ident("media_attachment.header"), false). + Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan). + Where("? IS NULL", bun.Ident("media_attachment.remote_url")). + Where("? IS NULL", bun.Ident("media_attachment.status_id")) if maxID != "" { - q = q.Where("media_attachment.id < ?", maxID) + q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID) } if limit != 0 { diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go @@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment mention := gtsmodel.Mention{} q := m.newMentionQ(&mention). - Where("mention.id = ?", id) + Where("? = ?", bun.Ident("mention.id"), id) if err := q.Scan(ctx); err != nil { return nil, m.conn.ProcessError(err) diff --git a/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go b/internal/db/bundb/migrations/20220612091800_duplicated_media_cleanup.go @@ -47,8 +47,8 @@ func init() { } if _, err := tx.NewDelete(). - Model(a). - WherePK(). + TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). + Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { l.Errorf("error deleting attachment with id %s: %s", a.ID, err) } else { diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" ) type notificationDB struct { @@ -44,7 +45,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo Relation("OriginAccount"). Relation("TargetAccount"). Relation("Status"). - WherePK() + Where("? = ?", bun.Ident("notification.id"), id) if err := q.Scan(ctx); err != nil { return nil, n.conn.ProcessError(err) @@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, q := n.conn. NewSelect(). - Table("notifications"). - Column("id") + TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). + Column("notification.id") if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("notification.id"), maxID) } if sinceID != "" { - q = q.Where("id > ?", sinceID) + q = q.Where("? > ?", bun.Ident("notification.id"), sinceID) } for _, excludeType := range excludeTypes { - q = q.Where("notification_type != ?", excludeType) + q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType) } q = q. - Where("target_account_id = ?", accountID). - Order("id DESC") + Where("? = ?", bun.Ident("notification.target_account_id"), accountID). + Order("notification.id DESC") if limit != 0 { q = q.Limit(limit) @@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error { if _, err := n.conn. NewDelete(). - Table("notifications"). - Where("target_account_id = ?", accountID). + TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). + Where("? = ?", bun.Ident("notification.target_account_id"), accountID). Exec(ctx); err != nil { return n.conn.ProcessError(err) } n.cache.Clear() - return nil } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go @@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { q := r.conn. NewSelect(). - Model(>smodel.Block{}). - ExcludeColumn("id", "created_at", "updated_at", "uri"). - Limit(1) + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id") if eitherDirection { q = q. WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { return inner. - Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) }). WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { return inner. - Where("account_id = ?", account2). - Where("target_account_id = ?", account1) + Where("? = ?", bun.Ident("block.account_id"), account2). + Where("? = ?", bun.Ident("block.target_account_id"), account1) }) } else { q = q. - Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) } return r.conn.Exists(ctx, q) @@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 block := >smodel.Block{} q := r.newBlockQ(block). - Where("block.account_id = ?", account1). - Where("block.target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) @@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount if err := r.conn. NewSelect(). Model(follow). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). + Column("follow.show_reblogs", "follow.notify"). + Where("? = ?", bun.Ident("follow.account_id"), requestingAccount). + Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount). Limit(1). Scan(ctx); err != nil { - if err != sql.ErrNoRows { - // a proper error - return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) + if err := r.conn.ProcessError(err); err != db.ErrNoEntries { + return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err) } // no follow exists so these are all false rel.Following = false @@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount } // check if the target account follows the requesting account - count, err := r.conn. + followedByQ := r.conn. NewSelect(). - Model(>smodel.Follow{}). - Where("account_id = ?", targetAccount). - Where("target_account_id = ?", requestingAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.id"). + Where("? = ?", bun.Ident("follow.account_id"), targetAccount). + Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount) + followedBy, err := r.conn.Exists(ctx, followedByQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err) } - rel.FollowedBy = count > 0 + rel.FollowedBy = followedBy - // check if the requesting account blocks the target account - count, err = r.conn.NewSelect(). - Model(>smodel.Block{}). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). - Limit(1). - Count(ctx) + // check if there's a pending following request from requesting account to target account + requestedQ := r.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Column("follow_request.id"). + Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount) + requested, err := r.conn.Exists(ctx, requestedQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err) } - rel.Blocking = count > 0 + rel.Requested = requested - // check if the target account blocks the requesting account - count, err = r.conn. + // check if the requesting account is blocking the target account + blockingQ := r.conn. NewSelect(). - Model(>smodel.Block{}). - Where("account_id = ?", targetAccount). - Where("target_account_id = ?", requestingAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id"). + Where("? = ?", bun.Ident("block.account_id"), requestingAccount). + Where("? = ?", bun.Ident("block.target_account_id"), targetAccount) + blocking, err := r.conn.Exists(ctx, blockingQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err) } - rel.BlockedBy = count > 0 + rel.Blocking = blocking - // check if there's a pending following request from requesting account to target account - count, err = r.conn. + // check if the requesting account is blocked by the target account + blockedByQ := r.conn. NewSelect(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id"). + Where("? = ?", bun.Ident("block.account_id"), targetAccount). + Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount) + blockedBy, err := r.conn.Exists(ctx, blockedByQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err) } - rel.Requested = count > 0 + rel.BlockedBy = blockedBy return rel, nil } @@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode q := r.conn. NewSelect(). - Model(>smodel.Follow{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID). - Limit(1) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.id"). + Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID). + Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID) return r.conn.Exists(ctx, q) } @@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g q := r.conn. NewSelect(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID) + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Column("follow_request.id"). + Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID) return r.conn.Exists(ctx, q) } @@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod } func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { - // make sure the original follow request exists - fr := >smodel.FollowRequest{} - if err := r.conn. - NewSelect(). - Model(fr). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Scan(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } + var follow *gtsmodel.Follow + + if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { + // get original follow request + followRequest := >smodel.FollowRequest{} + if err := tx. + NewSelect(). + Model(followRequest). + Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). + Scan(ctx); err != nil { + return err + } - // create a new follow to 'replace' the request with - follow := >smodel.Follow{ - ID: fr.ID, - AccountID: originAccountID, - TargetAccountID: targetAccountID, - URI: fr.URI, - } + // create a new follow to 'replace' the request with + follow = >smodel.Follow{ + ID: followRequest.ID, + AccountID: originAccountID, + TargetAccountID: targetAccountID, + URI: followRequest.URI, + } - // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := r.conn. - NewInsert(). - Model(follow). - On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI). - Exec(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } + // if the follow already exists, just update the URI -- we don't need to do anything else + if _, err := tx. + NewInsert(). + Model(follow). + On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). + Exec(ctx); err != nil { + return err + } + + // now remove the follow request + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). + Exec(ctx); err != nil { + return err + } - // now remove the follow request - if _, err := r.conn. - NewDelete(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Exec(ctx); err != nil { + return nil + }); err != nil { return nil, r.conn.ProcessError(err) } + // return the new follow return follow, nil } func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { - // first get the follow request out of the database - fr := >smodel.FollowRequest{} - if err := r.conn. - NewSelect(). - Model(fr). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Scan(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } + followRequest := >smodel.FollowRequest{} + + if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { + // get original follow request + if err := tx. + NewSelect(). + Model(followRequest). + Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). + Scan(ctx); err != nil { + return err + } - // now delete it from the database by ID - if _, err := r.conn. - NewDelete(). - Model(>smodel.FollowRequest{ID: fr.ID}). - WherePK(). - Exec(ctx); err != nil { + // now delete it from the database by ID + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). + Exec(ctx); err != nil { + return err + } + + return nil + }); err != nil { return nil, r.conn.ProcessError(err) } // return the deleted follow request - return fr, nil + return followRequest, nil } func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { followRequests := []*gtsmodel.FollowRequest{} q := r.newFollowQ(&followRequests). - Where("target_account_id = ?", accountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID). Order("follow_request.updated_at DESC") if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) } + return followRequests, nil } @@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string follows := []*gtsmodel.Follow{} q := r.newFollowQ(&follows). - Where("account_id = ?", accountID). + Where("? = ?", bun.Ident("follow.account_id"), accountID). Order("follow.updated_at DESC") if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) } + return follows, nil } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { - return r.conn. + q := r.conn. NewSelect(). - Model(&[]*gtsmodel.Follow{}). - Where("account_id = ?", accountID). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + + if localOnly { + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) + } else { + q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) + } + + return q.Count(ctx) } func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { @@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str Order("follow.updated_at DESC") if localOnly { - q = q.ColumnExpr("follow.*"). - Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)"). - Where("follow.target_account_id = ?", accountID). - WhereGroup(" AND ", whereEmptyOrNull("a.domain")) + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.target_account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) } else { - q = q.Where("target_account_id = ?", accountID) + q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) } err := q.Scan(ctx) @@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str } func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { - return r.conn. + q := r.conn. NewSelect(). - Model(&[]*gtsmodel.Follow{}). - Where("target_account_id = ?", accountID). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + + if localOnly { + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.target_account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) + } else { + q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) + } + + return q.Count(ctx) } diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go @@ -20,7 +20,6 @@ package bundb_test import ( "context" - "errors" "testing" "github.com/stretchr/testify/suite" @@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { suite.False(blocked) // have account1 block account2 - suite.db.Put(ctx, >smodel.Block{ + if err := suite.db.Put(ctx, >smodel.Block{ ID: "01G202BCSXXJZ70BHB5KCAHH8C", URI: "http://localhost:8080/some_block_uri_1", AccountID: account1, TargetAccountID: account2, - }) + }); err != nil { + suite.FailNow(err.Error()) + } // account 1 now blocks account 2 blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) @@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { } func (suite *RelationshipTestSuite) TestGetBlock() { - suite.Suite.T().Skip("TODO: implement") + ctx := context.Background() + + account1 := suite.testAccounts["local_account_1"].ID + account2 := suite.testAccounts["local_account_2"].ID + + if err := suite.db.Put(ctx, >smodel.Block{ + ID: "01G202BCSXXJZ70BHB5KCAHH8C", + URI: "http://localhost:8080/some_block_uri_1", + AccountID: account1, + TargetAccountID: account2, + }); err != nil { + suite.FailNow(err.Error()) + } + + block, err := suite.db.GetBlock(ctx, account1, account2) + suite.NoError(err) + suite.NotNil(block) + suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID) } func (suite *RelationshipTestSuite) TestGetRelationship() { - suite.Suite.T().Skip("TODO: implement") + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + + relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(relationship) + + suite.True(relationship.Following) + suite.True(relationship.ShowingReblogs) + suite.False(relationship.Notifying) + suite.True(relationship.FollowedBy) + suite.False(relationship.Blocking) + suite.False(relationship.BlockedBy) + suite.False(relationship.Muting) + suite.False(relationship.MutingNotifications) + suite.False(relationship.Requested) + suite.False(relationship.DomainBlocking) + suite.False(relationship.Endorsed) + suite.Empty(relationship.Note) +} + +func (suite *RelationshipTestSuite) TestIsFollowingYes() { + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.True(isFollowing) } -func (suite *RelationshipTestSuite) TestIsFollowing() { - suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestIsFollowingNo() { + requestingAccount := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.False(isFollowing) } func (suite *RelationshipTestSuite) TestIsMutualFollowing() { - suite.Suite.T().Skip("TODO: implement") + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.True(isMutualFollowing) +} + +func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() { + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["local_account_2"] + isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.True(isMutualFollowing) } -func (suite *RelationshipTestSuite) AcceptFollowRequest() { - for _, account := range suite.testAccounts { - _, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") - if err != nil && !errors.Is(err, db.ErrNoEntries) { - suite.Suite.Fail("error accepting follow request: %v", err) - } +func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, } + + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) + } + + follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(follow) + suite.Equal(followRequest.URI, follow.URI) } -func (suite *RelationshipTestSuite) GetAccountFollowRequests() { - suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) + suite.ErrorIs(err, db.ErrNoEntries) + suite.Nil(follow) } -func (suite *RelationshipTestSuite) GetAccountFollows() { - suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() { + ctx := context.Background() + account := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + + // follow already exists in the db from local_account_1 -> admin_account + existingFollow := >smodel.Follow{} + if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil { + suite.FailNow(err.Error()) + } + + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, + } + + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) + } + + follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(follow) + + // uri should be equal to value of new/overlapping follow request + suite.NotEqual(followRequest.URI, existingFollow.URI) + suite.Equal(followRequest.URI, follow.URI) } -func (suite *RelationshipTestSuite) CountAccountFollows() { - suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, + } + + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) + } + + rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(rejectedFollowRequest) } -func (suite *RelationshipTestSuite) GetAccountFollowedBy() { - // TODO: more comprehensive tests here +func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] - for _, account := range suite.testAccounts { - var err error + rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) + suite.ErrorIs(err, db.ErrNoEntries) + suite.Nil(rejectedFollowRequest) +} - _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) - if err != nil { - suite.Suite.Fail("error checking accounts followed by: %v", err) - } +func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] - _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) - if err != nil { - suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) - } + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, } + + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) + } + + followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) + suite.NoError(err) + suite.Len(followRequests, 1) } -func (suite *RelationshipTestSuite) CountAccountFollowedBy() { - suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestGetAccountFollows() { + account := suite.testAccounts["local_account_1"] + follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) + suite.NoError(err) + suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true) + suite.NoError(err) + suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollows() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false) + suite.NoError(err) + suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() { + account := suite.testAccounts["local_account_1"] + follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) + suite.NoError(err) + suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() { + account := suite.testAccounts["local_account_1"] + follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) + suite.NoError(err) + suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false) + suite.NoError(err) + suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true) + suite.NoError(err) + suite.Equal(2, followsCount) } func TestRelationshipTestSuite(t *testing.T) { diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go @@ -21,7 +21,6 @@ package bundb import ( "context" "crypto/rand" - "errors" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,29 +34,22 @@ type sessionDB struct { func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { rss := make([]*gtsmodel.RouterSession, 0, 1) - _, err := s.conn. + // get the first router session in the db or... + if err := s.conn. NewSelect(). Model(&rss). Limit(1). - Order("id DESC"). - Exec(ctx) - if err != nil { + Order("router_session.id DESC"). + Scan(ctx); err != nil { return nil, s.conn.ProcessError(err) } + // ... create a new one if len(rss) == 0 { - // no session created yet, so make one return s.createSession(ctx) } - if len(rss) != 1 { - // we asked for 1 so we should get 1 - return nil, errors.New("more than 1 router session was returned") - } - - // return the one session found - rs := rss[0] - return rs, nil + return rss[0], nil } func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { @@ -71,24 +63,23 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, return nil, err } - rid, err := id.NewULID() + id, err := id.NewULID() if err != nil { return nil, err } rs := >smodel.RouterSession{ - ID: rid, + ID: id, Auth: auth, Crypt: crypt, } - q := s.conn. + if _, err := s.conn. NewInsert(). - Model(rs) - - _, err = q.Exec(ctx) - if err != nil { + Model(rs). + Exec(ctx); err != nil { return nil, s.conn.ProcessError(err) } + return rs, nil } diff --git a/internal/db/bundb/session_test.go b/internal/db/bundb/session_test.go @@ -37,14 +37,13 @@ func (suite *SessionTestSuite) TestGetSession() { suite.NotEmpty(session.Crypt) suite.NotEmpty(session.ID) - // TODO -- the same session should be returned with consecutive selects - // right now there's an issue with bytea in bun, so uncomment this when that issue is fixed: https://github.com/uptrace/bun/issues/122 - // session2, err := suite.db.GetSession(context.Background()) - // suite.NoError(err) - // suite.NotNil(session2) - // suite.Equal(session.Auth, session2.Auth) - // suite.Equal(session.Crypt, session2.Crypt) - // suite.Equal(session.ID, session2.ID) + // the same session should be returned with consecutive selects + session2, err := suite.db.GetSession(context.Background()) + suite.NoError(err) + suite.NotNil(session2) + suite.Equal(session.Auth, session2.Auth) + suite.Equal(session.Crypt, session2.Crypt) + suite.Equal(session.ID, session2.ID) } func TestSessionTestSuite(t *testing.T) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go @@ -72,7 +72,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat return s.cache.GetByID(id) }, func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) + return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) }, ) } @@ -84,7 +84,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St return s.cache.GetByURI(uri) }, func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("status.uri = ?", uri).Scan(ctx) + return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) }, ) } @@ -96,7 +96,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St return s.cache.GetByURL(url) }, func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("status.url = ?", url).Scan(ctx) + return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) }, ) } @@ -109,8 +109,7 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta status = >smodel.Status{} // Not cached! Perform database query - err := dbQuery(status) - if err != nil { + if err := dbQuery(status); err != nil { return nil, s.conn.ProcessError(err) } @@ -138,24 +137,34 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { - return s.conn.RunInTx(ctx, func(tx bun.Tx) error { + err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { - if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Exec(ctx); err != nil { - return err + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + err = s.conn.errProc(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } // create links between this status and any tags it uses for _, i := range status.TagIDs { - if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Exec(ctx); err != nil { - return err + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { + err = s.conn.errProc(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } @@ -163,27 +172,46 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er for _, a := range status.Attachments { a.StatusID = status.ID a.UpdatedAt = time.Now() - if _, err := tx.NewUpdate().Model(a). - Where("id = ?", a.ID). + if _, err := tx. + NewUpdate(). + Model(a). + Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - return err + err = s.conn.errProc(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } // Finally, insert the status - _, err := tx.NewInsert().Model(status).Exec(ctx) - return err + if _, err := tx. + NewInsert(). + Model(status). + Exec(ctx); err != nil { + return err + } + + return nil }) + if err != nil { + return s.conn.ProcessError(err) + } + + s.cache.Put(status) + return nil } func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) { err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { - if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Exec(ctx); err != nil { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { err = s.conn.errProc(err) if !errors.Is(err, db.ErrAlreadyExists) { return err @@ -193,10 +221,12 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* // create links between this status and any tags it uses for _, i := range status.TagIDs { - if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Exec(ctx); err != nil { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { err = s.conn.errProc(err) if !errors.Is(err, db.ErrAlreadyExists) { return err @@ -208,23 +238,32 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* for _, a := range status.Attachments { a.StatusID = status.ID a.UpdatedAt = time.Now() - if _, err := tx.NewUpdate().Model(a). - Where("id = ?", a.ID). + if _, err := tx. + NewUpdate(). + Model(a). + Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { return err } } // Finally, update the status itself - if _, err := tx.NewUpdate().Model(status).WherePK().Exec(ctx); err != nil { + if _, err := tx. + NewUpdate(). + Model(status). + Where("? = ?", bun.Ident("status.id"), status.ID). + Exec(ctx); err != nil { return err } - s.cache.Put(status) return nil }) + if err != nil { + return nil, s.conn.ProcessError(err) + } - return status, err + s.cache.Put(status) + return status, nil } func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { @@ -232,8 +271,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { // delete links between this status and any emojis it uses if _, err := tx. NewDelete(). - Model(>smodel.StatusToEmoji{}). - Where("status_id = ?", id). + TableExpr("? AS ?", bun.Ident("status_to_emojis"), bun.Ident("status_to_emoji")). + Where("? = ?", bun.Ident("status_to_emoji.status_id"), id). Exec(ctx); err != nil { return err } @@ -241,8 +280,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { // delete links between this status and any tags it uses if _, err := tx. NewDelete(). - Model(>smodel.StatusToTag{}). - Where("status_id = ?", id). + TableExpr("? AS ?", bun.Ident("status_to_tags"), bun.Ident("status_to_tag")). + Where("? = ?", bun.Ident("status_to_tag.status_id"), id). Exec(ctx); err != nil { return err } @@ -250,17 +289,20 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { // delete the status itself if _, err := tx. NewDelete(). - Model(>smodel.Status{ID: id}). - WherePK(). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Where("? = ?", bun.Ident("status.id"), id). Exec(ctx); err != nil { return err } - s.cache.Invalidate(id) return nil }) + if err != nil { + return s.conn.ProcessError(err) + } - return s.conn.ProcessError(err) + s.cache.Invalidate(id) + return nil } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { @@ -312,11 +354,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, q := s.conn. NewSelect(). - Table("statuses"). - Column("id"). - Where("in_reply_to_id = ?", status.ID) + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.id"). + Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID) if minID != "" { - q = q.Where("id > ?", minID) + q = q.Where("? > ?", bun.Ident("status.id"), minID) } if err := q.Scan(ctx, &childIDs); err != nil { @@ -356,23 +398,35 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, } func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { - return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx) + return s.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID). + Count(ctx) } func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { - return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx) + return s.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). + Count(ctx) } func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { - return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx) + return s.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). + Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). + Count(ctx) } func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { q := s.conn. NewSelect(). - Model(>smodel.StatusFave{}). - Where("status_id = ?", status.ID). - Where("account_id = ?", accountID) + TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). + Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). + Where("? = ?", bun.Ident("status_fave.account_id"), accountID) return s.conn.Exists(ctx, q) } @@ -380,9 +434,9 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { q := s.conn. NewSelect(). - Model(>smodel.Status{}). - Where("boost_of_id = ?", status.ID). - Where("account_id = ?", accountID) + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). + Where("? = ?", bun.Ident("status.account_id"), accountID) return s.conn.Exists(ctx, q) } @@ -390,9 +444,9 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { q := s.conn. NewSelect(). - Model(>smodel.StatusMute{}). - Where("status_id = ?", status.ID). - Where("account_id = ?", accountID) + TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")). + Where("? = ?", bun.Ident("status_mute.status_id"), status.ID). + Where("? = ?", bun.Ident("status_mute.account_id"), accountID) return s.conn.Exists(ctx, q) } @@ -400,9 +454,9 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { q := s.conn. NewSelect(). - Model(>smodel.StatusBookmark{}). - Where("status_id = ?", status.ID). - Where("account_id = ?", accountID) + TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). + Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). + Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID) return s.conn.Exists(ctx, q) } @@ -410,8 +464,9 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { faves := []*gtsmodel.StatusFave{} - q := s.newFaveQ(&faves). - Where("status_id = ?", status.ID) + q := s. + newFaveQ(&faves). + Where("? = ?", bun.Ident("status_fave.status_id"), status.ID) if err := q.Scan(ctx); err != nil { return nil, s.conn.ProcessError(err) @@ -422,8 +477,9 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { reblogs := []*gtsmodel.Status{} - q := s.newStatusQ(&reblogs). - Where("boost_of_id = ?", status.ID) + q := s. + newStatusQ(&reblogs). + Where("? = ?", bun.Ident("status.boost_of_id"), status.ID) if err := q.Scan(ctx); err != nil { return nil, s.conn.ProcessError(err) diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go @@ -108,14 +108,14 @@ func (suite *StatusTestSuite) TestGetStatusTwice() { suite.NoError(err) after1 := time.Now() duration1 := after1.Sub(before1) - fmt.Println(duration1.Milliseconds()) + fmt.Println(duration1.Microseconds()) before2 := time.Now() _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) suite.NoError(err) after2 := time.Now() duration2 := after2.Sub(before2) - fmt.Println(duration2.Milliseconds()) + fmt.Println(duration2.Microseconds()) // second retrieval should be several orders faster since it will be cached now suite.Less(duration2, duration1) diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go @@ -34,38 +34,48 @@ type timelineDB struct { } func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } + // Make educated guess for slice size statusIDs := make([]string, 0, limit) q := t.conn. NewSelect(). - Table("statuses"). - + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table - Column("statuses.id"). + Column("status.id"). // Find out who accountID follows. - Join("LEFT JOIN follows ON follows.target_account_id = statuses.account_id AND follows.account_id = ?", accountID). + Join("LEFT JOIN ? AS ? ON ? = ? AND ? = ?", + bun.Ident("follows"), + bun.Ident("follow"), + bun.Ident("follow.target_account_id"), + bun.Ident("status.account_id"), + bun.Ident("follow.account_id"), + accountID). // Sort by highest ID (newest) to lowest ID (oldest) - Order("statuses.id DESC") + Order("status.id DESC") if maxID != "" { // return only statuses LOWER (ie., older) than maxID - q = q.Where("statuses.id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } if sinceID != "" { // return only statuses HIGHER (ie., newer) than sinceID - q = q.Where("statuses.id > ?", sinceID) + q = q.Where("? > ?", bun.Ident("status.id"), sinceID) } if minID != "" { // return only statuses HIGHER (ie., newer) than minID - q = q.Where("statuses.id > ?", minID) + q = q.Where("? > ?", bun.Ident("status.id"), minID) } if local { // return only statuses posted by local account havers - q = q.Where("statuses.local = ?", local) + q = q.Where("? = ?", bun.Ident("status.local"), local) } if limit > 0 { @@ -78,13 +88,11 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI // // This is equivalent to something like WHERE ... AND (... OR ...) // See: https://bun.uptrace.dev/guide/queries.html#select - whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { + q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery { return q. - WhereOr("follows.account_id = ?", accountID). - WhereOr("statuses.account_id = ?", accountID) - } - - q = q.WhereGroup(" AND ", whereGroup) + WhereOr("? = ?", bun.Ident("follow.account_id"), accountID). + WhereOr("? = ?", bun.Ident("status.account_id"), accountID) + }) if err := q.Scan(ctx, &statusIDs); err != nil { return nil, t.conn.ProcessError(err) @@ -118,28 +126,28 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma q := t.conn. NewSelect(). - Table("statuses"). - Column("statuses.id"). - Where("statuses.visibility = ?", gtsmodel.VisibilityPublic). - WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_id")). - WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_uri")). - WhereGroup(" AND ", whereEmptyOrNull("statuses.boost_of_id")). - Order("statuses.id DESC") + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.id"). + Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). + WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")). + WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). + WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). + Order("status.id DESC") if maxID != "" { - q = q.Where("statuses.id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } if sinceID != "" { - q = q.Where("statuses.id > ?", sinceID) + q = q.Where("? > ?", bun.Ident("status.id"), sinceID) } if minID != "" { - q = q.Where("statuses.id > ?", minID) + q = q.Where("? > ?", bun.Ident("status.id"), minID) } if local { - q = q.Where("statuses.local = ?", local) + q = q.Where("? = ?", bun.Ident("status.local"), local) } if limit > 0 { @@ -181,15 +189,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max fq := t.conn. NewSelect(). Model(&faves). - Where("account_id = ?", accountID). - Order("id DESC") + Where("? = ?", bun.Ident("status_fave.account_id"), accountID). + Order("status_fave.id DESC") if maxID != "" { - fq = fq.Where("id < ?", maxID) + fq = fq.Where("? < ?", bun.Ident("status_fave.id"), maxID) } if minID != "" { - fq = fq.Where("id > ?", minID) + fq = fq.Where("? > ?", bun.Ident("status_fave.id"), minID) } if limit > 0 { diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go @@ -38,6 +38,15 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() { suite.Len(s, 6) } +func (suite *TimelineTestSuite) TestGetHomeTimeline() { + viewingAccount := suite.testAccounts["local_account_1"] + + s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) + suite.NoError(err) + + suite.Len(s, 16) +} + func TestTimelineTestSuite(t *testing.T) { suite.Run(t, new(TimelineTestSuite)) } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go @@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db return u.cache.GetByID(id) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx) }, ) } @@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts return u.cache.GetByAccountID(accountID) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx) }, ) } @@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) return u.cache.GetByEmail(emailAddress) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx) }, ) } @@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok return u.cache.GetByConfirmationToken(confirmationToken) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx) }, ) } @@ -127,7 +127,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. if _, err := u.conn. NewUpdate(). Model(user). - WherePK(). + Where("? = ?", bun.Ident("user.id"), user.ID). Column(columns...). Exec(ctx); err != nil { return nil, u.conn.ProcessError(err) @@ -140,8 +140,8 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { if _, err := u.conn. NewDelete(). - Model(>smodel.User{ID: userID}). - WherePK(). + TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). + Where("? = ?", bun.Ident("user.id"), userID). Exec(ctx); err != nil { return u.conn.ProcessError(err) } diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go @@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) { return } - if w.CaseInsensitive { - query = "LOWER(?) != LOWER(?)" - args = []interface{}{bun.Safe(w.Key), w.Value} - return - } - query = "? != ?" - args = []interface{}{bun.Safe(w.Key), w.Value} + args = []interface{}{bun.Ident(w.Key), w.Value} return } @@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) { return } - if w.CaseInsensitive { - query = "LOWER(?) = LOWER(?)" - args = []interface{}{bun.Safe(w.Key), w.Value} - return - } - query = "? = ?" - args = []interface{}{bun.Safe(w.Key), w.Value} + args = []interface{}{bun.Ident(w.Key), w.Value} return } diff --git a/internal/db/params.go b/internal/db/params.go @@ -24,9 +24,6 @@ type Where struct { Key string // The value to match. Value interface{} - // Whether the value (if a string) should be case sensitive or not. - // Defaults to false. - CaseInsensitive bool // If set, reverse the where. // `WHERE k = v` becomes `WHERE k != v`. // `WHERE k IS NULL` becomes `WHERE k IS NOT NULL` diff --git a/internal/media/processingmedia.go b/internal/media/processingmedia.go @@ -101,7 +101,7 @@ func (p *ProcessingMedia) LoadAttachment(ctx context.Context) (*gtsmodel.MediaAt if !p.insertedInDB { if p.recache { // if it's a recache we should only need to update - if err := p.database.UpdateByPrimaryKey(ctx, p.attachment); err != nil { + if err := p.database.UpdateByID(ctx, p.attachment, p.attachment.ID); err != nil { return nil, err } } else { diff --git a/internal/media/prunemeta_test.go b/internal/media/prunemeta_test.go @@ -40,7 +40,7 @@ func (suite *PruneMetaTestSuite) TestPruneMeta() { zork := suite.testAccounts["local_account_1"] zork.AvatarMediaAttachmentID = "" zork.HeaderMediaAttachmentID = "" - if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { + if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { panic(err) } @@ -72,7 +72,7 @@ func (suite *PruneMetaTestSuite) TestPruneMetaTwice() { zork := suite.testAccounts["local_account_1"] zork.AvatarMediaAttachmentID = "" zork.HeaderMediaAttachmentID = "" - if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { + if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { panic(err) } @@ -95,14 +95,14 @@ func (suite *PruneMetaTestSuite) TestPruneMetaMultipleAccounts() { zork := suite.testAccounts["local_account_1"] zork.AvatarMediaAttachmentID = "" zork.HeaderMediaAttachmentID = "" - if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { + if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { panic(err) } // set zork's unused header as belonging to turtle turtle := suite.testAccounts["local_account_1"] zorkOldHeader.AccountID = turtle.ID - if err := suite.db.UpdateByPrimaryKey(ctx, zorkOldHeader, "account_id"); err != nil { + if err := suite.db.UpdateByID(ctx, zorkOldHeader, zorkOldHeader.ID, "account_id"); err != nil { panic(err) } diff --git a/internal/media/pruneremote.go b/internal/media/pruneremote.go @@ -90,7 +90,7 @@ func (m *manager) pruneOneRemote(ctx context.Context, attachment *gtsmodel.Media // update the attachment to reflect that we no longer have it cached if changed { - return m.db.UpdateByPrimaryKey(ctx, attachment, "updated_at", "cached") + return m.db.UpdateByID(ctx, attachment, attachment.ID, "updated_at", "cached") } return nil diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go @@ -128,15 +128,17 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account instance.ContactAccountUsername = "" instance.ContactAccountID = "" instance.Version = "" - if err := p.db.UpdateByPrimaryKey(ctx, instance, updatingColumns...); err != nil { + if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil { l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) } l.Debug("domainBlockProcessSideEffects: instance entry updated") } // if we have an instance account for this instance, delete it - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil { - l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err) + if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { + if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil { + l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err) + } } // delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines) diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go @@ -55,14 +55,14 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc // remove the domain block reference from the instance, if we have an entry for it i := >smodel.Instance{} if err := p.db.GetWhere(ctx, []db.Where{ - {Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true}, + {Key: "domain", Value: domainBlock.Domain}, {Key: "domain_block_id", Value: id}, }, i); err == nil { updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"} i.SuspendedAt = time.Time{} i.DomainBlockID = "" i.UpdatedAt = time.Now() - if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil { + if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) } } diff --git a/internal/processing/instance.go b/internal/processing/instance.go @@ -224,7 +224,7 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe } } - if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil { + if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err)) } diff --git a/internal/processing/media/getfile_test.go b/internal/processing/media/getfile_test.go @@ -69,7 +69,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncached() { // uncache the file from local testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"] testAttachment.Cached = testrig.FalseBool() - err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached") + err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached") suite.NoError(err) err = suite.storage.Delete(ctx, testAttachment.File.Path) suite.NoError(err) @@ -124,7 +124,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() { // uncache the file from local testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"] testAttachment.Cached = testrig.FalseBool() - err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached") + err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached") suite.NoError(err) err = suite.storage.Delete(ctx, testAttachment.File.Path) suite.NoError(err) @@ -179,7 +179,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileThumbnailUncached() { // uncache the file from local testAttachment.Cached = testrig.FalseBool() - err = suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached") + err = suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached") suite.NoError(err) err = suite.storage.Delete(ctx, testAttachment.File.Path) suite.NoError(err) diff --git a/internal/processing/media/unattach.go b/internal/processing/media/unattach.go @@ -47,7 +47,7 @@ func (p *processor) Unattach(ctx context.Context, account *gtsmodel.Account, med attachment.UpdatedAt = time.Now() attachment.StatusID = "" - if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil { + if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err)) } diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go @@ -61,7 +61,7 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, media updatingColumns = append(updatingColumns, "focus_x", "focus_y") } - if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil { + if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err)) } diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go @@ -162,27 +162,28 @@ func (p *processor) ProcessMediaIDs(ctx context.Context, form *apimodel.Advanced return nil } - gtsMediaAttachments := []*gtsmodel.MediaAttachment{} - attachments := []string{} + attachments := []*gtsmodel.MediaAttachment{} + attachmentIDs := []string{} for _, mediaID := range form.MediaIDs { - // check these attachments exist - a := >smodel.MediaAttachment{} - if err := p.db.GetByID(ctx, mediaID, a); err != nil { - return fmt.Errorf("invalid media type or media not found for media id %s", mediaID) + attachment, err := p.db.GetAttachmentByID(ctx, mediaID) + if err != nil { + return fmt.Errorf("ProcessMediaIDs: invalid media type or media not found for media id %s", mediaID) } - // check they belong to the requesting account id - if a.AccountID != thisAccountID { - return fmt.Errorf("media with id %s does not belong to account %s", mediaID, thisAccountID) + + if attachment.AccountID != thisAccountID { + return fmt.Errorf("ProcessMediaIDs: media with id %s does not belong to account %s", mediaID, thisAccountID) } - // check they're not already used in a status - if a.StatusID != "" || a.ScheduledStatusID != "" { - return fmt.Errorf("media with id %s is already attached to a status", mediaID) + + if attachment.StatusID != "" || attachment.ScheduledStatusID != "" { + return fmt.Errorf("ProcessMediaIDs: media with id %s is already attached to a status", mediaID) } - gtsMediaAttachments = append(gtsMediaAttachments, a) - attachments = append(attachments, a.ID) + + attachments = append(attachments, attachment) + attachmentIDs = append(attachmentIDs, attachment.ID) } - status.Attachments = gtsMediaAttachments - status.AttachmentIDs = attachments + + status.Attachments = attachments + status.AttachmentIDs = attachmentIDs return nil } diff --git a/internal/processing/user/changepassword.go b/internal/processing/user/changepassword.go @@ -45,7 +45,7 @@ func (p *processor) ChangePassword(ctx context.Context, user *gtsmodel.User, old user.EncryptedPassword = string(newPasswordHash) user.UpdatedAt = time.Now() - if err := p.db.UpdateByPrimaryKey(ctx, user, "encrypted_password", "updated_at"); err != nil { + if err := p.db.UpdateByID(ctx, user, user.ID, "encrypted_password", "updated_at"); err != nil { return gtserror.NewErrorInternalError(err, "database error") } diff --git a/internal/processing/user/emailconfirm.go b/internal/processing/user/emailconfirm.go @@ -77,7 +77,7 @@ func (p *processor) SendConfirmEmail(ctx context.Context, user *gtsmodel.User, u user.LastEmailedAt = time.Now() user.UpdatedAt = time.Now() - if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil { + if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err) } @@ -126,7 +126,7 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U user.ConfirmationToken = "" user.UpdatedAt = time.Now() - if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil { + if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/user/emailconfirm_test.go b/internal/processing/user/emailconfirm_test.go @@ -74,7 +74,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmail() { user.ConfirmationSentAt = time.Now().Add(-5 * time.Minute) user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6" - err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...) + err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...) suite.NoError(err) // confirm with the token set above @@ -102,7 +102,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmailOldToken() { user.ConfirmationSentAt = time.Now().Add(-192 * time.Hour) user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6" - err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...) + err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...) suite.NoError(err) // confirm with the token set above diff --git a/testrig/db.go b/testrig/db.go @@ -187,7 +187,7 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { } for _, v := range NewTestStatuses() { - if err := db.PutStatus(ctx, v); err != nil { + if err := db.Put(ctx, v); err != nil { log.Panic(err) } } @@ -198,12 +198,24 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { } } + for _, v := range NewTestStatusToEmojis() { + if err := db.Put(ctx, v); err != nil { + log.Panic(err) + } + } + for _, v := range NewTestTags() { if err := db.Put(ctx, v); err != nil { log.Panic(err) } } + for _, v := range NewTestStatusToTags() { + if err := db.Put(ctx, v); err != nil { + log.Panic(err) + } + } + for _, v := range NewTestMentions() { if err := db.Put(ctx, v); err != nil { log.Panic(err) diff --git a/testrig/testmodels.go b/testrig/testmodels.go @@ -977,6 +977,15 @@ func NewTestEmojis() map[string]*gtsmodel.Emoji { } } +func NewTestStatusToEmojis() map[string]*gtsmodel.StatusToEmoji { + return map[string]*gtsmodel.StatusToEmoji{ + "admin_account_status_1_rainbow": { + StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R", + EmojiID: "01F8MH9H8E4VG3KDYJR9EGPXCQ", + }, + } +} + func NewTestInstances() map[string]*gtsmodel.Instance { return map[string]*gtsmodel.Instance{ "localhost:8080": { @@ -1540,6 +1549,15 @@ func NewTestTags() map[string]*gtsmodel.Tag { } } +func NewTestStatusToTags() map[string]*gtsmodel.StatusToTag { + return map[string]*gtsmodel.StatusToTag{ + "admin_account_status_1_welcome": { + StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R", + TagID: "01F8MHA1A2NF9MJ3WCCQ3K8BSZ", + }, + } +} + // NewTestMentions returns a map of gts model mentions keyed by their name. func NewTestMentions() map[string]*gtsmodel.Mention { return map[string]*gtsmodel.Mention{