commit 56f53a2a6f85876485e2ae67d48b78b448caed6e parent f7af7c061c41de08a455d39da490f4e52dd5e025 Author: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Mon, 3 Oct 2022 10:46:11 +0200 [performance] add user cache and database (#879) * go fmt * add + use user cache and database * fix import * update tests * remove unused relation Diffstat:
21 files changed, 490 insertions(+), 70 deletions(-)
diff --git a/cmd/gotosocial/action/admin/account/account.go b/cmd/gotosocial/action/admin/account/account.go @@ -26,9 +26,7 @@ import ( "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/bundb" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/validate" "golang.org/x/crypto/bcrypt" ) @@ -92,8 +90,8 @@ var Confirm action.GTSAction = func(ctx context.Context) error { return err } - u := >smodel.User{} - if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + u, err := dbConn.GetUserByAccountID(ctx, a.ID) + if err != nil { return err } @@ -130,8 +128,8 @@ var Promote action.GTSAction = func(ctx context.Context) error { return err } - u := >smodel.User{} - if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + u, err := dbConn.GetUserByAccountID(ctx, a.ID) + if err != nil { return err } @@ -139,7 +137,7 @@ var Promote action.GTSAction = func(ctx context.Context) error { admin := true u.Admin = &admin u.UpdatedAt = time.Now() - if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { + if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { return err } @@ -166,8 +164,8 @@ var Demote action.GTSAction = func(ctx context.Context) error { return err } - u := >smodel.User{} - if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + u, err := dbConn.GetUserByAccountID(ctx, a.ID) + if err != nil { return err } @@ -175,7 +173,7 @@ var Demote action.GTSAction = func(ctx context.Context) error { admin := false u.Admin = &admin u.UpdatedAt = time.Now() - if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { + if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { return err } @@ -202,8 +200,8 @@ var Disable action.GTSAction = func(ctx context.Context) error { return err } - u := >smodel.User{} - if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + u, err := dbConn.GetUserByAccountID(ctx, a.ID) + if err != nil { return err } @@ -211,7 +209,7 @@ var Disable action.GTSAction = func(ctx context.Context) error { disabled := true u.Disabled = &disabled u.UpdatedAt = time.Now() - if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { + if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { return err } @@ -252,8 +250,8 @@ var Password action.GTSAction = func(ctx context.Context) error { return err } - u := >smodel.User{} - if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + u, err := dbConn.GetUserByAccountID(ctx, a.ID) + if err != nil { return err } @@ -265,7 +263,7 @@ var Password action.GTSAction = func(ctx context.Context) error { updatingColumns := []string{"encrypted_password", "updated_at"} u.EncryptedPassword = string(pw) u.UpdatedAt = time.Now() - if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { + if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { return err } diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go @@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { return } - user := >smodel.User{} - if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { + user, err := m.db.GetUserByID(c.Request.Context(), userID) + if err != nil { m.clearSession(s) safe := fmt.Sprintf("user with id %s could not be retrieved", userID) var errWithCode gtserror.WithCode @@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) { return } - user := >smodel.User{} - if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { + user, err := m.db.GetUserByID(c.Request.Context(), userID) + if err != nil { m.clearSession(s) safe := fmt.Sprintf("user with id %s could not be retrieved", userID) var errWithCode gtserror.WithCode diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go @@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { doTest := func(testCase authorizeHandlerTestCase) { ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "") - user := suite.testUsers["unconfirmed_account"] - account := suite.testAccounts["unconfirmed_account"] + user := >smodel.User{} + account := >smodel.Account{} + + *user = *suite.testUsers["unconfirmed_account"] + *account = *suite.testAccounts["unconfirmed_account"] testSession := sessions.Default(ctx) testSession.Set(sessionUserID, user.ID) @@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt) updatingColumns = append(updatingColumns, "updated_at") - user.UpdatedAt = time.Now() - err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...) + _, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...) suite.NoError(err) _, err = suite.db.UpdateAccount(context.Background(), account) suite.NoError(err) diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go @@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i // see if we already have a user for this email address // if so, we don't need to continue + create one - user := >smodel.User{} - err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user) + user, err := m.db.GetUserByEmailAddress(ctx, claims.Email) if err == nil { return user, nil } diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go @@ -28,9 +28,7 @@ import ( "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "golang.org/x/crypto/bcrypt" ) @@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st return incorrectPassword(err) } - user := >smodel.User{} - if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil { + user, err := m.db.GetUserByEmailAddress(ctx, email) + if err != nil { err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) return incorrectPassword(err) } diff --git a/internal/api/security/tokencheck.go b/internal/api/security/tokencheck.go @@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) { log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope()) // fetch user for this token - user := >smodel.User{} - if err := m.db.GetByID(ctx, userID, user); err != nil { + user, err := m.db.GetUserByID(ctx, userID) + if err != nil { if err != db.ErrNoEntries { log.Errorf("database error looking for user with id %s: %s", userID, err) return @@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) { c.Set(oauth.SessionAuthorizedUser, user) // fetch account for this token - acct, err := m.db.GetAccountByID(ctx, user.AccountID) - if err != nil { - if err != db.ErrNoEntries { - log.Errorf("database error looking for account with id %s: %s", user.AccountID, err) + if user.Account == nil { + acct, err := m.db.GetAccountByID(ctx, user.AccountID) + if err != nil { + if err != db.ErrNoEntries { + log.Errorf("database error looking for account with id %s: %s", user.AccountID, err) + return + } + log.Warnf("no account found for userID %s", userID) return } - log.Warnf("no account found for userID %s", userID) - return + user.Account = acct } - if !acct.SuspendedAt.IsZero() { + if !user.Account.SuspendedAt.IsZero() { log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID) return } - c.Set(oauth.SessionAuthorizedAccount, acct) + c.Set(oauth.SessionAuthorizedAccount, user.Account) } // check for application token diff --git a/internal/cache/user.go b/internal/cache/user.go @@ -0,0 +1,141 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package cache + +import ( + "time" + + "codeberg.org/gruf/go-cache/v2" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// UserCache is a cache wrapper to provide lookups for gtsmodel.User +type UserCache struct { + cache cache.LookupCache[string, string, *gtsmodel.User] +} + +// NewUserCache returns a new instantiated UserCache object +func NewUserCache() *UserCache { + c := &UserCache{} + c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{ + RegisterLookups: func(lm *cache.LookupMap[string, string]) { + lm.RegisterLookup("accountid") + lm.RegisterLookup("email") + lm.RegisterLookup("unconfirmedemail") + lm.RegisterLookup("confirmationtoken") + }, + + AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) { + lm.Set("accountid", user.AccountID, user.ID) + if email := user.Email; email != "" { + lm.Set("email", email, user.ID) + } + if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" { + lm.Set("unconfirmedemail", unconfirmedEmail, user.ID) + } + if confirmationToken := user.ConfirmationToken; confirmationToken != "" { + lm.Set("confirmationtoken", confirmationToken, user.ID) + } + }, + + DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) { + lm.Delete("accountid", user.AccountID) + if email := user.Email; email != "" { + lm.Delete("email", email) + } + if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" { + lm.Delete("unconfirmedemail", unconfirmedEmail) + } + if confirmationToken := user.ConfirmationToken; confirmationToken != "" { + lm.Delete("confirmationtoken", confirmationToken) + } + }, + }) + c.cache.SetTTL(time.Minute*5, false) + c.cache.Start(time.Second * 10) + return c +} + +// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety +func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) { + return c.cache.Get(id) +} + +// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety +func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) { + return c.cache.GetBy("accountid", accountID) +} + +// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety +func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) { + return c.cache.GetBy("email", email) +} + +// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety +func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) { + return c.cache.GetBy("confirmationtoken", token) +} + +// Put places a user in the cache, ensuring that the object place is a copy for thread-safety +func (c *UserCache) Put(user *gtsmodel.User) { + if user == nil || user.ID == "" { + panic("invalid user") + } + c.cache.Set(user.ID, copyUser(user)) +} + +// Invalidate invalidates one user from the cache using the ID of the user as key. +func (c *UserCache) Invalidate(userID string) { + c.cache.Invalidate(userID) +} + +func copyUser(user *gtsmodel.User) *gtsmodel.User { + return >smodel.User{ + ID: user.ID, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + Email: user.Email, + AccountID: user.AccountID, + Account: nil, + EncryptedPassword: user.EncryptedPassword, + SignUpIP: user.SignUpIP, + CurrentSignInAt: user.CurrentSignInAt, + CurrentSignInIP: user.CurrentSignInIP, + LastSignInAt: user.LastSignInAt, + LastSignInIP: user.LastSignInIP, + SignInCount: user.SignInCount, + InviteID: user.InviteID, + ChosenLanguages: user.ChosenLanguages, + FilteredLanguages: user.FilteredLanguages, + Locale: user.Locale, + CreatedByApplicationID: user.CreatedByApplicationID, + CreatedByApplication: nil, + LastEmailedAt: user.LastEmailedAt, + ConfirmationToken: user.ConfirmationToken, + ConfirmationSentAt: user.ConfirmationSentAt, + ConfirmedAt: user.ConfirmedAt, + UnconfirmedEmail: user.UnconfirmedEmail, + Moderator: copyBoolPtr(user.Moderator), + Admin: copyBoolPtr(user.Admin), + Disabled: copyBoolPtr(user.Disabled), + Approved: copyBoolPtr(user.Approved), + ResetPasswordToken: user.ResetPasswordToken, + ResetPasswordSentAt: user.ResetPasswordSentAt, + } +} diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go @@ -30,6 +30,7 @@ import ( "time" "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -40,7 +41,8 @@ import ( ) type adminDB struct { - conn *DBConn + conn *DBConn + userCache *cache.UserCache } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { @@ -175,6 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, Exec(ctx); err != nil { return nil, a.conn.ProcessError(err) } + a.userCache.Put(u) return u, nil } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go @@ -87,6 +87,7 @@ type DBService struct { db.Session db.Status db.Timeline + db.User conn *DBConn } @@ -181,13 +182,15 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { notifCache.SetTTL(time.Minute*5, false) notifCache.Start(time.Second * 10) - // Prepare domain block cache + // Prepare other caches blockCache := cache.NewDomainBlockCache() + userCache := cache.NewUserCache() ps := &DBService{ Account: accounts, Admin: &adminDB{ - conn: conn, + conn: conn, + userCache: userCache, }, Basic: &basicDB{ conn: conn, @@ -219,7 +222,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { }, Status: status, Timeline: timeline, - conn: conn, + User: &userDB{ + conn: conn, + cache: userCache, + }, + conn: conn, } // we can confidently return this useable service now diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go @@ -0,0 +1,151 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type userDB struct { + conn *DBConn + cache *cache.UserCache +} + +func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery { + return u.conn. + NewSelect(). + Model(user). + Relation("Account") +} + +func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) { + // Attempt to fetch cached user + user, cached := cacheGet() + + if !cached { + user = >smodel.User{} + + // Not cached! Perform database query + err := dbQuery(user) + if err != nil { + return nil, u.conn.ProcessError(err) + } + + // Place in the cache + u.cache.Put(user) + } + + return user, nil +} + +func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { + return u.getUser( + ctx, + func() (*gtsmodel.User, bool) { + return u.cache.GetByID(id) + }, + func(user *gtsmodel.User) error { + return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx) + }, + ) +} + +func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { + return u.getUser( + ctx, + func() (*gtsmodel.User, bool) { + return u.cache.GetByAccountID(accountID) + }, + func(user *gtsmodel.User) error { + return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx) + }, + ) +} + +func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { + return u.getUser( + ctx, + func() (*gtsmodel.User, bool) { + return u.cache.GetByEmail(emailAddress) + }, + func(user *gtsmodel.User) error { + return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx) + }, + ) +} + +func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { + return u.getUser( + ctx, + func() (*gtsmodel.User, bool) { + return u.cache.GetByConfirmationToken(confirmationToken) + }, + func(user *gtsmodel.User) error { + return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx) + }, + ) +} + +func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) { + if _, err := u.conn. + NewInsert(). + Model(user). + Exec(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + u.cache.Put(user) + return user, nil +} + +func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) { + // Update the user's last-updated + user.UpdatedAt = time.Now() + + if _, err := u.conn. + NewUpdate(). + Model(user). + WherePK(). + Column(columns...). + Exec(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + u.cache.Invalidate(user.ID) + return user, nil +} + +func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { + if _, err := u.conn. + NewDelete(). + Model(>smodel.User{ID: userID}). + WherePK(). + Exec(ctx); err != nil { + return u.conn.ProcessError(err) + } + + u.cache.Invalidate(userID) + return nil +} diff --git a/internal/db/bundb/user_test.go b/internal/db/bundb/user_test.go @@ -0,0 +1,73 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type UserTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *UserTestSuite) TestGetUser() { + user, err := suite.db.GetUserByID(context.Background(), suite.testUsers["local_account_1"].ID) + suite.NoError(err) + suite.NotNil(user) +} + +func (suite *UserTestSuite) TestGetUserByEmailAddress() { + user, err := suite.db.GetUserByEmailAddress(context.Background(), suite.testUsers["local_account_1"].Email) + suite.NoError(err) + suite.NotNil(user) +} + +func (suite *UserTestSuite) TestGetUserByAccountID() { + user, err := suite.db.GetUserByAccountID(context.Background(), suite.testAccounts["local_account_1"].ID) + suite.NoError(err) + suite.NotNil(user) +} + +func (suite *UserTestSuite) TestUpdateUserSelectedColumns() { + testUser := suite.testUsers["local_account_1"] + user := >smodel.User{ + ID: testUser.ID, + Email: "whatever", + Locale: "es", + } + + user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale") + suite.NoError(err) + suite.NotNil(user) + + dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID) + suite.NoError(err) + suite.NotNil(dbUser) + suite.Equal("whatever", dbUser.Email) + suite.Equal("es", dbUser.Locale) + suite.Equal(testUser.AccountID, dbUser.AccountID) +} + +func TestUserTestSuite(t *testing.T) { + suite.Run(t, new(UserTestSuite)) +} diff --git a/internal/db/db.go b/internal/db/db.go @@ -44,6 +44,7 @@ type DB interface { Session Status Timeline + User /* USEFUL CONVERSION FUNCTIONS diff --git a/internal/db/user.go b/internal/db/user.go @@ -0,0 +1,42 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// User contains functions related to user getting/setting/creation. +type User interface { + // GetUserByID returns one user with the given ID, or an error if something goes wrong. + GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error) + // GetUserByAccountID returns one user by its account ID, or an error if something goes wrong. + GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error) + // GetUserByID returns one user with the given email address, or an error if something goes wrong. + GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error) + // GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong. + GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error) + // UpdateUser updates one user by its primary key. If columns is set, only given columns + // will be updated. If not set, all columns will be updated. + UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error) + // DeleteUserByID deletes one user by its ID. + DeleteUserByID(ctx context.Context, userID string) Error +} diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go @@ -70,13 +70,14 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi // 1. Delete account's application(s), clients, and oauth tokens // we only need to do this step for local account since remote ones won't have any tokens or applications on our server + var user *gtsmodel.User if account.Domain == "" { // see if we can get a user for this account - u := >smodel.User{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil { + var err error + if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil { // we got one! select all tokens with the user's ID tokens := []*gtsmodel.Token{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil { // we have some tokens to delete for _, t := range tokens { // delete client(s) associated with this token @@ -240,9 +241,11 @@ selectStatusesLoop: // TODO // 16. Delete account's user - l.Debug("deleting account user") - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil { - return gtserror.NewErrorInternalError(err) + if user != nil { + l.Debug("deleting account user") + if err := p.db.DeleteUserByID(ctx, user.ID); err != nil { + return gtserror.NewErrorInternalError(err) + } } // 17. Delete account's timeline @@ -288,8 +291,8 @@ func (p *processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account, if form.DeleteOriginID == account.ID { // the account owner themself has requested deletion via the API, get their user from the db - user := >smodel.User{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil { + user, err := p.db.GetUserByAccountID(ctx, account.ID) + if err != nil { return gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go @@ -29,7 +29,6 @@ import ( "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/messages" @@ -138,8 +137,8 @@ func (p *processor) processCreateAccountFromClientAPI(ctx context.Context, clien } // get the user this account belongs to - user := >smodel.User{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil { + user, err := p.db.GetUserByAccountID(ctx, account.ID) + if err != nil { return err } diff --git a/internal/processing/fromfederator_test.go b/internal/processing/fromfederator_test.go @@ -370,7 +370,7 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() { // no statuses from foss satan should be left in the database if !testrig.WaitFor(func() bool { s, err := suite.db.GetAccountStatuses(ctx, deletedAccount.ID, 0, false, false, "", "", false, false, false) - return s == nil && err == db.ErrNoEntries + return s == nil && err == db.ErrNoEntries }) { suite.FailNow("timeout waiting for statuses to be deleted") } diff --git a/internal/processing/instance.go b/internal/processing/instance.go @@ -142,8 +142,8 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername)) } // make sure it has a user associated with it - contactUser := >smodel.User{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { + contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID) + if err != nil { return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername)) } // suspended accounts cannot be contact accounts diff --git a/internal/processing/streaming/authorize.go b/internal/processing/streaming/authorize.go @@ -40,8 +40,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s return nil, gtserror.NewErrorUnauthorized(err) } - user := >smodel.User{} - if err := p.db.GetByID(ctx, uid, user); err != nil { + user, err := p.db.GetUserByID(ctx, uid) + if err != nil { if err == db.ErrNoEntries { err := fmt.Errorf("no user found for validated uid %s", uid) return nil, gtserror.NewErrorUnauthorized(err) diff --git a/internal/processing/user/emailconfirm.go b/internal/processing/user/emailconfirm.go @@ -89,8 +89,8 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U return nil, gtserror.NewErrorNotFound(errors.New("no token provided")) } - user := >smodel.User{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "confirmation_token", Value: token}}, user); err != nil { + user, err := p.db.GetUserByConfirmationToken(ctx, token) + if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(err) } diff --git a/internal/typeutils/internaltofrontend_test.go b/internal/typeutils/internaltofrontend_test.go @@ -46,9 +46,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontend() { func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct() { testAccount := suite.testAccounts["local_account_1"] // take zork for this test testEmoji := suite.testEmojis["rainbow"] - + testAccount.Emojis = []*gtsmodel.Emoji{testEmoji} - + apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount) suite.NoError(err) suite.NotNil(apiAccount) @@ -61,9 +61,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct() func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiIDs() { testAccount := suite.testAccounts["local_account_1"] // take zork for this test testEmoji := suite.testEmojis["rainbow"] - + testAccount.EmojiIDs = []string{testEmoji.ID} - + apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount) suite.NoError(err) suite.NotNil(apiAccount) diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go @@ -68,8 +68,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu // if the target user doesn't exist (anymore) then the status also shouldn't be visible // note: we only do this for local users if targetAccount.Domain == "" { - targetUser := >smodel.User{} - if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil { + targetUser, err := f.db.GetUserByAccountID(ctx, targetAccount.ID) + if err != nil { l.Debug("target user could not be selected") if err == db.ErrNoEntries { return false, nil @@ -98,8 +98,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu // if the requesting user doesn't exist (anymore) then the status also shouldn't be visible // note: we only do this for local users if requestingAccount.Domain == "" { - requestingUser := >smodel.User{} - if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil { + requestingUser, err := f.db.GetUserByAccountID(ctx, requestingAccount.ID) + if err != nil { // if the requesting account is local but doesn't have a corresponding user in the db this is a problem l.Debug("requesting user could not be selected") if err == db.ErrNoEntries {