commit 7d193de25fbccc00923d6d791d6d4e0d2d5d498e parent ed462245730bd7832019bd43e0bc1c9d1c055e8e Author: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Wed, 1 Sep 2021 10:08:21 +0100 Improve GetRemoteStatus and db.GetStatus() logic (#174) * only fetch status parents / children if explicity requested when dereferencing Signed-off-by: kim (grufwub) <grufwub@gmail.com> * Remove recursive DB GetStatus logic, don't fetch parent unless requested Signed-off-by: kim (grufwub) <grufwub@gmail.com> * StatusCache copies status so there are no thread-safety issues with modified status objects Signed-off-by: kim (grufwub) <grufwub@gmail.com> * remove sqlite test files Signed-off-by: kim (grufwub) <grufwub@gmail.com> * fix bugs introduced by previous commit Signed-off-by: kim (grufwub) <grufwub@gmail.com> * fix not continue on error in loop Signed-off-by: kim (grufwub) <grufwub@gmail.com> * use our own RunInTx implementation (possible fix for nested tx error) Signed-off-by: kim (grufwub) <grufwub@gmail.com> * fix cast statement to work with SQLite Signed-off-by: kim (grufwub) <grufwub@gmail.com> * be less strict about valid status in cache Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add cache=shared ALWAYS for SQLite db instances Signed-off-by: kim (grufwub) <grufwub@gmail.com> * Fix EnrichRemoteAccount when updating account fails Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add nolint tag Signed-off-by: kim (grufwub) <grufwub@gmail.com> * ensure file: prefixes the filename in sqlite addr Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add an account cache, add status author account from db Signed-off-by: kim (grufwub) <grufwub@gmail.com> * Fix incompatible SQLite query Signed-off-by: kim (grufwub) <grufwub@gmail.com> * *actually* use the new getAccount() function in accountsDB Signed-off-by: kim (grufwub) <grufwub@gmail.com> * update cache tests to use test suite Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add RelationshipTestSuite, add tests for methods with changed SQL Signed-off-by: kim (grufwub) <grufwub@gmail.com> Diffstat:
36 files changed, 653 insertions(+), 227 deletions(-)
diff --git a/internal/api/client/account/sqlite-test.db b/internal/api/client/account/sqlite-test.db Binary files differ. diff --git a/internal/api/client/fileserver/sqlite-test.db b/internal/api/client/fileserver/sqlite-test.db Binary files differ. diff --git a/internal/api/client/media/sqlite-test.db b/internal/api/client/media/sqlite-test.db Binary files differ. diff --git a/internal/api/client/status/sqlite-test.db b/internal/api/client/status/sqlite-test.db Binary files differ. diff --git a/internal/api/s2s/user/sqlite-test.db b/internal/api/s2s/user/sqlite-test.db Binary files differ. diff --git a/internal/cache/account.go b/internal/cache/account.go @@ -0,0 +1,157 @@ +package cache + +import ( + "sync" + + "github.com/ReneKroon/ttlcache" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// AccountCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Account +type AccountCache struct { + cache *ttlcache.Cache // map of IDs -> cached accounts + urls map[string]string // map of account URLs -> IDs + uris map[string]string // map of account URIs -> IDs + mutex sync.Mutex +} + +// NewAccountCache returns a new instantiated AccountCache object +func NewAccountCache() *AccountCache { + c := AccountCache{ + cache: ttlcache.NewCache(), + urls: make(map[string]string, 100), + uris: make(map[string]string, 100), + mutex: sync.Mutex{}, + } + + // Set callback to purge lookup maps on expiration + c.cache.SetExpirationCallback(func(key string, value interface{}) { + account := value.(*gtsmodel.Account) + + c.mutex.Lock() + delete(c.urls, account.URL) + delete(c.uris, account.URI) + c.mutex.Unlock() + }) + + return &c +} + +// GetByID attempts to fetch a account from the cache by its ID, you will receive a copy for thread-safety +func (c *AccountCache) GetByID(id string) (*gtsmodel.Account, bool) { + c.mutex.Lock() + account, ok := c.getByID(id) + c.mutex.Unlock() + return account, ok +} + +// GetByURL attempts to fetch a account from the cache by its URL, you will receive a copy for thread-safety +func (c *AccountCache) GetByURL(url string) (*gtsmodel.Account, bool) { + // Perform safe ID lookup + c.mutex.Lock() + id, ok := c.urls[url] + + // Not found, unlock early + if !ok { + c.mutex.Unlock() + return nil, false + } + + // Attempt account lookup + account, ok := c.getByID(id) + c.mutex.Unlock() + return account, ok +} + +// GetByURI attempts to fetch a account from the cache by its URI, you will receive a copy for thread-safety +func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) { + // Perform safe ID lookup + c.mutex.Lock() + id, ok := c.uris[uri] + + // Not found, unlock early + if !ok { + c.mutex.Unlock() + return nil, false + } + + // Attempt account lookup + account, ok := c.getByID(id) + c.mutex.Unlock() + return account, ok +} + +// getByID performs an unsafe (no mutex locks) lookup of account by ID, returning a copy of account in cache +func (c *AccountCache) getByID(id string) (*gtsmodel.Account, bool) { + v, ok := c.cache.Get(id) + if !ok { + return nil, false + } + return copyAccount(v.(*gtsmodel.Account)), true +} + +// Put places a account in the cache, ensuring that the object place is a copy for thread-safety +func (c *AccountCache) Put(account *gtsmodel.Account) { + if account == nil || account.ID == "" { + panic("invalid account") + } + + c.mutex.Lock() + c.cache.Set(account.ID, copyAccount(account)) + if account.URL != "" { + c.urls[account.URL] = account.ID + } + if account.URI != "" { + c.uris[account.URI] = account.ID + } + c.mutex.Unlock() +} + +// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. +// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) +// this should be a relatively cheap process +func copyAccount(account *gtsmodel.Account) *gtsmodel.Account { + return >smodel.Account{ + ID: account.ID, + Username: account.Username, + Domain: account.Domain, + AvatarMediaAttachmentID: account.AvatarMediaAttachmentID, + AvatarMediaAttachment: nil, + AvatarRemoteURL: account.AvatarRemoteURL, + HeaderMediaAttachmentID: account.HeaderMediaAttachmentID, + HeaderMediaAttachment: nil, + HeaderRemoteURL: account.HeaderRemoteURL, + DisplayName: account.DisplayName, + Fields: account.Fields, + Note: account.Note, + Memorial: account.Memorial, + MovedToAccountID: account.MovedToAccountID, + CreatedAt: account.CreatedAt, + UpdatedAt: account.UpdatedAt, + Bot: account.Bot, + Reason: account.Reason, + Locked: account.Locked, + Discoverable: account.Discoverable, + Privacy: account.Privacy, + Sensitive: account.Sensitive, + Language: account.Language, + URI: account.URI, + URL: account.URL, + LastWebfingeredAt: account.LastWebfingeredAt, + InboxURI: account.InboxURI, + OutboxURI: account.OutboxURI, + FollowingURI: account.FollowingURI, + FollowersURI: account.FollowersURI, + FeaturedCollectionURI: account.FeaturedCollectionURI, + ActorType: account.ActorType, + AlsoKnownAs: account.AlsoKnownAs, + PrivateKey: account.PrivateKey, + PublicKey: account.PublicKey, + PublicKeyURI: account.PublicKeyURI, + SensitizedAt: account.SensitizedAt, + SilencedAt: account.SilencedAt, + SuspendedAt: account.SuspendedAt, + HideCollections: account.HideCollections, + SuspensionOrigin: account.SuspensionOrigin, + } +} diff --git a/internal/cache/account_test.go b/internal/cache/account_test.go @@ -0,0 +1,63 @@ +package cache_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type AccountCacheTestSuite struct { + suite.Suite + data map[string]*gtsmodel.Account + cache *cache.AccountCache +} + +func (suite *AccountCacheTestSuite) SetupSuite() { + suite.data = testrig.NewTestAccounts() +} + +func (suite *AccountCacheTestSuite) SetupTest() { + suite.cache = cache.NewAccountCache() +} + +func (suite *AccountCacheTestSuite) TearDownTest() { + suite.data = nil + suite.cache = nil +} + +func (suite *AccountCacheTestSuite) TestAccountCache() { + for _, account := range suite.data { + // Place in the cache + suite.cache.Put(account) + } + + for _, account := range suite.data { + var ok bool + var check *gtsmodel.Account + + // Check we can retrieve + check, ok = suite.cache.GetByID(account.ID) + if !ok && !accountIs(account, check) { + suite.Fail("Failed to fetch expected account with ID: %s", account.ID) + } + check, ok = suite.cache.GetByURI(account.URI) + if account.URI != "" && !ok && !accountIs(account, check) { + suite.Fail("Failed to fetch expected account with URI: %s", account.URI) + } + check, ok = suite.cache.GetByURL(account.URL) + if account.URL != "" && !ok && !accountIs(account, check) { + suite.Fail("Failed to fetch expected account with URL: %s", account.URL) + } + } +} + +func TestAccountCache(t *testing.T) { + suite.Run(t, &AccountCacheTestSuite{}) +} + +func accountIs(account1, account2 *gtsmodel.Account) bool { + return account1.ID == account2.ID && account1.URI == account2.URI && account1.URL == account2.URL +} diff --git a/internal/cache/status.go b/internal/cache/status.go @@ -37,7 +37,7 @@ func NewStatusCache() *StatusCache { return &c } -// GetByID attempts to fetch a status from the cache by its ID +// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) { c.mutex.Lock() status, ok := c.getByID(id) @@ -45,7 +45,7 @@ func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) { return status, ok } -// GetByURL attempts to fetch a status from the cache by its URL +// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) { // Perform safe ID lookup c.mutex.Lock() @@ -63,7 +63,7 @@ func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) { return status, ok } -// GetByURI attempts to fetch a status from the cache by its URI +// GetByURI attempts to fetch a status from the cache by its URI, you will receive a copy for thread-safety func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) { // Perform safe ID lookup c.mutex.Lock() @@ -81,26 +81,72 @@ func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) { return status, ok } -// getByID performs an unsafe (no mutex locks) lookup of status by ID +// getByID performs an unsafe (no mutex locks) lookup of status by ID, returning a copy of status in cache func (c *StatusCache) getByID(id string) (*gtsmodel.Status, bool) { v, ok := c.cache.Get(id) if !ok { return nil, false } - return v.(*gtsmodel.Status), true + return copyStatus(v.(*gtsmodel.Status)), true } -// Put places a status in the cache +// Put places a status in the cache, ensuring that the object place is a copy for thread-safety func (c *StatusCache) Put(status *gtsmodel.Status) { - if status == nil || status.ID == "" || - status.URL == "" || - status.URI == "" { + if status == nil || status.ID == "" { panic("invalid status") } c.mutex.Lock() - c.cache.Set(status.ID, status) - c.urls[status.URL] = status.ID - c.uris[status.URI] = status.ID + c.cache.Set(status.ID, copyStatus(status)) + if status.URL != "" { + c.urls[status.URL] = status.ID + } + if status.URI != "" { + c.uris[status.URI] = status.ID + } c.mutex.Unlock() } + +// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects. +// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) +// this should be a relatively cheap process +func copyStatus(status *gtsmodel.Status) *gtsmodel.Status { + return >smodel.Status{ + ID: status.ID, + URI: status.URI, + URL: status.URL, + Content: status.Content, + AttachmentIDs: status.AttachmentIDs, + Attachments: nil, + TagIDs: status.TagIDs, + Tags: nil, + MentionIDs: status.MentionIDs, + Mentions: nil, + EmojiIDs: status.EmojiIDs, + Emojis: nil, + CreatedAt: status.CreatedAt, + UpdatedAt: status.UpdatedAt, + Local: status.Local, + AccountID: status.AccountID, + Account: nil, + AccountURI: status.AccountURI, + InReplyToID: status.InReplyToID, + InReplyTo: nil, + InReplyToURI: status.InReplyToURI, + InReplyToAccountID: status.InReplyToAccountID, + InReplyToAccount: nil, + BoostOfID: status.BoostOfID, + BoostOf: nil, + BoostOfAccountID: status.BoostOfAccountID, + BoostOfAccount: nil, + ContentWarning: status.ContentWarning, + Visibility: status.Visibility, + Sensitive: status.Sensitive, + Language: status.Language, + CreatedWithApplicationID: status.CreatedWithApplicationID, + VisibilityAdvanced: status.VisibilityAdvanced, + ActivityStreamsType: status.ActivityStreamsType, + Text: status.Text, + Pinned: status.Pinned, + } +} diff --git a/internal/cache/status_test.go b/internal/cache/status_test.go @@ -3,39 +3,61 @@ package cache_test import ( "testing" + "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" ) -func TestStatusCache(t *testing.T) { - cache := cache.NewStatusCache() +type StatusCacheTestSuite struct { + suite.Suite + data map[string]*gtsmodel.Status + cache *cache.StatusCache +} - // Attempt to place a status - status := gtsmodel.Status{ - ID: "id", - URI: "uri", - URL: "url", - } - cache.Put(&status) +func (suite *StatusCacheTestSuite) SetupSuite() { + suite.data = testrig.NewTestStatuses() +} - var ok bool - var check *gtsmodel.Status +func (suite *StatusCacheTestSuite) SetupTest() { + suite.cache = cache.NewStatusCache() +} - // Check we can retrieve - check, ok = cache.GetByID(status.ID) - if !ok || !statusIs(&status, check) { - t.Fatal("Could not find expected status") - } - check, ok = cache.GetByURI(status.URI) - if !ok || !statusIs(&status, check) { - t.Fatal("Could not find expected status") +func (suite *StatusCacheTestSuite) TearDownTest() { + suite.data = nil + suite.cache = nil +} + +func (suite *StatusCacheTestSuite) TestStatusCache() { + for _, status := range suite.data { + // Place in the cache + suite.cache.Put(status) } - check, ok = cache.GetByURL(status.URL) - if !ok || !statusIs(&status, check) { - t.Fatal("Could not find expected status") + + for _, status := range suite.data { + var ok bool + var check *gtsmodel.Status + + // Check we can retrieve + check, ok = suite.cache.GetByID(status.ID) + if !ok && !statusIs(status, check) { + suite.Fail("Failed to fetch expected account with ID: %s", status.ID) + } + check, ok = suite.cache.GetByURI(status.URI) + if status.URI != "" && !ok && !statusIs(status, check) { + suite.Fail("Failed to fetch expected account with URI: %s", status.URI) + } + check, ok = suite.cache.GetByURL(status.URL) + if status.URL != "" && !ok && !statusIs(status, check) { + suite.Fail("Failed to fetch expected account with URL: %s", status.URL) + } } } +func TestStatusCache(t *testing.T) { + suite.Run(t, &StatusCacheTestSuite{}) +} + func statusIs(status1, status2 *gtsmodel.Status) bool { return status1.ID == status2.ID && status1.URI == status2.URI && status1.URL == status2.URL } diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,6 +35,7 @@ import ( type accountDB struct { config *config.Config conn *DBConn + cache *cache.AccountCache } func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -45,60 +47,80 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { } func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) - - q := a.newAccountQ(account). - Where("account.id = ?", id) - - err := q.Scan(ctx) - if err != nil { - return nil, a.conn.ProcessError(err) - } - return account, nil + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByID(id) + }, + func(account *gtsmodel.Account) error { + return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) + }, + ) } func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) - - q := a.newAccountQ(account). - Where("account.uri = ?", uri) + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByURI(uri) + }, + func(account *gtsmodel.Account) error { + return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) + }, + ) +} - err := q.Scan(ctx) - if err != nil { - return nil, a.conn.ProcessError(err) - } - return account, nil +func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByURL(url) + }, + func(account *gtsmodel.Account) error { + return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) + }, + ) } -func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) +func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { + // Attempt to fetch cached account + account, cached := cacheGet() - q := a.newAccountQ(account). - Where("account.url = ?", uri) + if !cached { + account = >smodel.Account{} - err := q.Scan(ctx) - if err != nil { - return nil, a.conn.ProcessError(err) + // Not cached! Perform database query + err := dbQuery(account) + if err != nil { + return nil, a.conn.ProcessError(err) + } + + // Place in the cache + a.cache.Put(account) } + return account, nil } func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { if strings.TrimSpace(account.ID) == "" { + // TODO: we should not need this check here return nil, errors.New("account had no ID") } + // Update the account's last-used account.UpdatedAt = time.Now() - q := a.conn. - NewUpdate(). - Model(account). - WherePK() - - _, err := q.Exec(ctx) + // Update the account model in the DB + _, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx) if err != nil { return nil, a.conn.ProcessError(err) } + + // Place updated account in cache + // (this will replace existing, i.e. invalidating) + a.cache.Put(account) + return account, nil } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go @@ -91,6 +91,15 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log) case dbTypeSqlite: // SQLITE + + // Drop anything fancy from DB address + c.DBConfig.Address = strings.Split(c.DBConfig.Address, "?")[0] + c.DBConfig.Address = strings.TrimPrefix(c.DBConfig.Address, "file:") + + // Append our own SQLite preferences + c.DBConfig.Address = "file:" + c.DBConfig.Address + "?cache=shared" + + // Open new DB instance var err error sqldb, err = sql.Open("sqlite", c.DBConfig.Address) if err != nil { @@ -98,7 +107,7 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) } conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log) - if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") { + if c.DBConfig.Address == "file::memory:?cache=shared" { log.Warn("sqlite in-memory database should only be used for debugging") // don't close connections on disconnect -- otherwise @@ -121,11 +130,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) conn.RegisterModel(t) } + accounts := &accountDB{config: c, conn: conn, cache: cache.NewAccountCache()} + ps := &bunDBService{ - Account: &accountDB{ - config: c, - conn: conn, - }, + Account: accounts, Admin: &adminDB{ config: c, conn: conn, @@ -165,9 +173,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) conn: conn, }, Status: &statusDB{ - config: c, - conn: conn, - cache: cache.NewStatusCache(), + config: c, + conn: conn, + cache: cache.NewStatusCache(), + accounts: accounts, }, Timeline: &timelineDB{ config: c, diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go @@ -12,6 +12,8 @@ import ( // dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality type DBConn struct { + // TODO: move *Config here, no need to be in each struct type + errProc func(error) db.Error // errProc is the SQL-type specific error processor log *logrus.Logger // log is the logger passed with this DBConn *bun.DB // DB is the underlying bun.DB connection @@ -35,6 +37,24 @@ func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn { } } +func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error { + // Acquire a new transaction + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return conn.ProcessError(err) + } + + // Perform supplied transaction + if err = fn(tx); err != nil { + tx.Rollback() //nolint + return conn.ProcessError(err) + } + + // Finally, commit transaction + err = tx.Commit() + return conn.ProcessError(err) +} + // ProcessError processes an error to replace any known values with our own db.Error types, // making it easier to catch specific situations (e.g. no rows, already exists, etc) func (conn *DBConn) ProcessError(err error) db.Error { diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go @@ -237,7 +237,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI if _, err := r.conn. NewInsert(). Model(follow). - On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). + On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI). Exec(ctx); err != nil { return nil, r.conn.ProcessError(err) } @@ -298,7 +298,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str if localOnly { q = q.ColumnExpr("follow.*"). - Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). + Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)"). Where("follow.target_account_id = ?", accountID). WhereGroup(" AND ", whereEmptyOrNull("a.domain")) } else { diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go @@ -0,0 +1,124 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type RelationshipTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *RelationshipTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *RelationshipTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *RelationshipTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *RelationshipTestSuite) TestIsBlocked() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestGetBlock() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestGetRelationship() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestIsFollowing() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestIsMutualFollowing() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) AcceptFollowRequest() { + for _, account := range suite.testAccounts { + _, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") + if err != nil && !errors.Is(err, db.ErrNoEntries) { + suite.Suite.Fail("error accepting follow request: %v", err) + } + } +} + +func (suite *RelationshipTestSuite) GetAccountFollowRequests() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) GetAccountFollows() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) CountAccountFollows() { + suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) GetAccountFollowedBy() { + // TODO: more comprehensive tests here + + for _, account := range suite.testAccounts { + var err error + + _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) + if err != nil { + suite.Suite.Fail("error checking accounts followed by: %v", err) + } + + _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) + if err != nil { + suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) + } + } +} + +func (suite *RelationshipTestSuite) CountAccountFollowedBy() { + suite.Suite.T().Skip("TODO: implement") +} + +func TestRelationshipTestSuite(t *testing.T) { + suite.Run(t, new(RelationshipTestSuite)) +} diff --git a/internal/db/bundb/sqlite-test.db b/internal/db/bundb/sqlite-test.db Binary files differ. diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go @@ -21,7 +21,6 @@ package bundb import ( "container/list" "context" - "errors" "time" "github.com/superseriousbusiness/gotosocial/internal/cache" @@ -35,6 +34,11 @@ type statusDB struct { config *config.Config conn *DBConn cache *cache.StatusCache + + // TODO: keep method definitions in same place but instead have receiver + // all point to one single "db" type, so they can all share methods + // and caches where necessary + accounts *accountDB } func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { @@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { Relation("CreatedWithApplication") } -func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { - if status.InReplyToID != "" && status.InReplyTo == nil { - // TODO: do we want to keep this possibly recursive strategy? - - if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached { - status.InReplyTo = inReplyTo - } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { - status.InReplyTo = inReplyTo - } - } - - if status.BoostOfID != "" && status.BoostOf == nil { - // TODO: do we want to keep this possibly recursive strategy? - - if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached { - status.BoostOf = boostOf - } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { - status.BoostOf = boostOf - } - } - - return status -} - func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { return s.conn. NewSelect(). @@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { } func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { - if status, cached := s.cache.GetByID(id); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("status.id = ?", id) - - err := q.Scan(ctx) - if err != nil { - return nil, s.conn.ProcessError(err) - } - - s.cache.Put(status) - return s.getAttachedStatuses(ctx, status), nil + return s.getStatus( + ctx, + func() (*gtsmodel.Status, bool) { + return s.cache.GetByID(id) + }, + func(status *gtsmodel.Status) error { + return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) + }, + ) } func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.cache.GetByURI(uri); cached { - return status, nil - } - - status := >smodel.Status{} + return s.getStatus( + ctx, + func() (*gtsmodel.Status, bool) { + return s.cache.GetByURI(uri) + }, + func(status *gtsmodel.Status) error { + return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx) + }, + ) +} - q := s.newStatusQ(status). - Where("LOWER(status.uri) = LOWER(?)", uri) +func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { + return s.getStatus( + ctx, + func() (*gtsmodel.Status, bool) { + return s.cache.GetByURL(url) + }, + func(status *gtsmodel.Status) error { + return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx) + }, + ) +} - err := q.Scan(ctx) - if err != nil { - return nil, s.conn.ProcessError(err) - } +func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { + // Attempt to fetch cached status + status, cached := cacheGet() - s.cache.Put(status) - return s.getAttachedStatuses(ctx, status), nil -} + if !cached { + status = >smodel.Status{} -func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { - if status, cached := s.cache.GetByURL(url); cached { - return status, nil - } + // Not cached! Perform database query + err := dbQuery(status) + if err != nil { + return nil, s.conn.ProcessError(err) + } - status := >smodel.Status{} + // If there is boosted, fetch from DB also + if status.BoostOfID != "" { + boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) + if err == nil { + status.BoostOf = boostOf + } + } - q := s.newStatusQ(status). - Where("LOWER(status.url) = LOWER(?)", url) + // Place in the cache + s.cache.Put(status) + } - err := q.Scan(ctx) + // Set the status author account + author, err := s.accounts.GetAccountByID(ctx, status.AccountID) if err != nil { - return nil, s.conn.ProcessError(err) + return nil, err } - s.cache.Put(status) - return s.getAttachedStatuses(ctx, status), nil + // Return the prepared status + status.Account = author + return status, nil } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { - transaction := func(ctx context.Context, tx bun.Tx) error { + return s.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ @@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er } } + // Finally, insert the status _, err := tx.NewInsert().Model(status).Exec(ctx) return err - } - return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction)) + }) } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { @@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu children := []*gtsmodel.Status{} for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - // only append children, not the overall parent status + entry := e.Value.(*gtsmodel.Status) if entry.ID != status.ID { children = append(children, entry) } @@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, for _, child := range immediateChildren { insertLoop: for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - + entry := e.Value.(*gtsmodel.Status) if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { foundStatuses.InsertAfter(child, e) break insertLoop diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go @@ -105,10 +105,9 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() { suite.NotNil(status) suite.NotNil(status.Account) suite.NotNil(status.CreatedWithApplication) - suite.NotEmpty(status.Mentions) suite.NotEmpty(status.MentionIDs) - suite.NotNil(status.InReplyTo) - suite.NotNil(status.InReplyToAccount) + suite.NotEmpty(status.InReplyToID) + suite.NotEmpty(status.InReplyToAccountID) } func (suite *StatusTestSuite) TestGetStatusTwice() { diff --git a/internal/db/status.go b/internal/db/status.go @@ -26,13 +26,13 @@ import ( // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. type Status interface { - // GetStatusByID returns one status from the database, with all rel fields populated (if possible). + // GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) - // GetStatusByURI returns one status from the database, with all rel fields populated (if possible). + // GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) - // GetStatusByURL returns one status from the database, with all rel fields populated (if possible). + // GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error) // PutStatus stores one status in the database. diff --git a/internal/federation/dereference.go b/internal/federation/dereference.go @@ -34,12 +34,12 @@ func (f *federator) EnrichRemoteAccount(ctx context.Context, username string, ac return f.dereferencer.EnrichRemoteAccount(ctx, username, account) } -func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { - return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh) +func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) { + return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh, includeParent, includeChilds) } -func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { - return f.dereferencer.EnrichRemoteStatus(ctx, username, status) +func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) { + return f.dereferencer.EnrichRemoteStatus(ctx, username, status, includeParent, includeChilds) } func (f *federator) DereferenceRemoteThread(ctx context.Context, username string, statusIRI *url.URL) error { diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go @@ -48,7 +48,6 @@ func instanceAccount(account *gtsmodel.Account) bool { // EnrichRemoteAccount is mostly useful for calling after an account has been initially created by // the federatingDB's Create function, or during the federated authorization flow. func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { - // if we're dealing with an instance account, we don't need to update anything if instanceAccount(account) { return account, nil @@ -58,13 +57,13 @@ func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, accoun return nil, err } - var err error - account, err = d.db.UpdateAccount(ctx, account) + updated, err := d.db.UpdateAccount(ctx, account) if err != nil { d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err) + return account, nil } - return account, nil + return updated, nil } // GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account, diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go @@ -46,7 +46,7 @@ func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Stat return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err) } - boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false) + boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false, false, false) if err != nil { return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err) } diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go @@ -38,8 +38,8 @@ type Dereferencer interface { GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) - GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) - EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) + GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) + EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) diff --git a/internal/federation/dereferencing/sqlite-test.db b/internal/federation/dereferencing/sqlite-test.db Binary files differ. diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go @@ -39,8 +39,8 @@ import ( // // EnrichRemoteStatus is mostly useful for calling after a status has been initially created by // the federatingDB's Create function, but additional dereferencing is needed on it. -func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { - if err := d.populateStatusFields(ctx, status, username); err != nil { +func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) { + if err := d.populateStatusFields(ctx, status, username, includeParent, includeChilds); err != nil { return nil, err } @@ -62,7 +62,7 @@ func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status // If a dereference was performed, then the function also returns the ap.Statusable representation for further processing. // // SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored. -func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { +func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) { new := true // check if we already have the status in our db @@ -105,7 +105,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat } gtsStatus.ID = ulid - if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { + if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err) } @@ -115,7 +115,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat } else { gtsStatus.ID = maybeStatus.ID - if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { + if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err) } @@ -235,7 +235,7 @@ func (d *deref) dereferenceStatusable(ctx context.Context, username string, remo // This function will deference all of the above, insert them in the database as necessary, // and attach them to the status. The status itself will not be added to the database yet, // that's up the caller to do. -func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error { +func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string, includeParent, includeChilds bool) error { l := d.log.WithFields(logrus.Fields{ "func": "dereferenceStatusFields", "status": fmt.Sprintf("%+v", status), @@ -275,14 +275,19 @@ func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Statu // 3. Emojis // TODO - // 4. Mentions - if err := d.populateStatusMentions(ctx, status, requestingUsername); err != nil { - return fmt.Errorf("populateStatusFields: error populating status mentions: %s", err) + // 4. Mentions (only if requested) + // TODO: do we need to handle removing empty mention objects and just using mention IDs slice? + if includeChilds { + if err := d.populateStatusMentions(ctx, status, requestingUsername); err != nil { + return fmt.Errorf("populateStatusFields: error populating status mentions: %s", err) + } } - // 5. Replied-to-status. - if err := d.populateStatusRepliedTo(ctx, status, requestingUsername); err != nil { - return fmt.Errorf("populateStatusFields: error populating status repliedTo: %s", err) + // 5. Replied-to-status (only if requested) + if includeParent { + if err := d.populateStatusRepliedTo(ctx, status, requestingUsername); err != nil { + return fmt.Errorf("populateStatusFields: error populating status repliedTo: %s", err) + } } return nil @@ -391,7 +396,6 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel. attachments := []*gtsmodel.MediaAttachment{} for _, a := range status.Attachments { - aURL, err := url.Parse(a.RemoteURL) if err != nil { l.Errorf("populateStatusAttachments: couldn't parse attachment url %s: %s", a.RemoteURL, err) @@ -401,6 +405,7 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel. attachment, err := d.GetRemoteAttachment(ctx, requestingUsername, aURL, status.AccountID, status.ID, a.File.ContentType) if err != nil { l.Errorf("populateStatusAttachments: couldn't get remote attachment %s: %s", a.RemoteURL, err) + continue } attachmentIDs = append(attachmentIDs, attachment.ID) @@ -420,29 +425,16 @@ func (d *deref) populateStatusRepliedTo(ctx context.Context, status *gtsmodel.St return err } - var replyToStatus *gtsmodel.Status - errs := []string{} - // see if we have the status in our db already - if s, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err != nil { - errs = append(errs, err.Error()) - } else { - replyToStatus = s - } - - if replyToStatus == nil { - // didn't find the status in our db, try to get it remotely - if s, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err != nil { - errs = append(errs, err.Error()) - } else { - replyToStatus = s + replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI) + if err != nil { + // Status was not in the DB, try fetch + replyToStatus, _, _, err = d.GetRemoteStatus(ctx, requestingUsername, statusURI, false, false, false) + if err != nil { + return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", status.InReplyToURI, err) } } - if replyToStatus == nil { - return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", statusURI, strings.Join(errs, " : ")) - } - // we have the status status.InReplyToID = replyToStatus.ID status.InReplyTo = replyToStatus diff --git a/internal/federation/dereferencing/status_test.go b/internal/federation/dereferencing/status_test.go @@ -119,7 +119,7 @@ func (suite *StatusTestSuite) TestDereferenceSimpleStatus() { fetchingAccount := suite.testAccounts["local_account_1"] statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE4NTHKWW7THT67EF10EB839") - status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false) + status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, false) suite.NoError(err) suite.NotNil(status) suite.NotNil(statusable) @@ -157,7 +157,7 @@ func (suite *StatusTestSuite) TestDereferenceStatusWithMention() { fetchingAccount := suite.testAccounts["local_account_1"] statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE5Y30E3W4P7TRE0R98KAYQV") - status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false) + status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, true) suite.NoError(err) suite.NotNil(status) suite.NotNil(statusable) diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go @@ -49,7 +49,7 @@ func (d *deref) DereferenceThread(ctx context.Context, username string, statusIR } // first make sure we have this status in our db - _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true) + _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true, false, false) if err != nil { return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err) } @@ -104,7 +104,7 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI // If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus // We call it with refresh to true because we want the statusable representation to parse inReplyTo from. - status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true) + _, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true, false, false) if err != nil { l.Debugf("error getting remote status: %s", err) return nil @@ -116,18 +116,6 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI return nil } - // get the ancestor status into our database if we don't have it yet - if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil { - l.Debugf("error getting remote status: %s", err) - return nil - } - - // now enrich the current status, since we should have the ancestor in the db - if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil { - l.Debugf("error enriching remote status: %s", err) - return nil - } - // now move up to the next ancestor return d.iterateAncestors(ctx, username, *inReplyTo) } @@ -226,7 +214,7 @@ pageLoop: foundReplies = foundReplies + 1 // get the remote statusable and put it in the db - _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false) + _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false, false, false) if new && err == nil && statusable != nil { // now iterate descendants of *that* status if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil { diff --git a/internal/federation/federator.go b/internal/federation/federator.go @@ -62,8 +62,8 @@ type Federator interface { GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) - GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) - EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) + GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) + EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) @@ -88,7 +88,6 @@ type federator struct { // NewFederator returns a new federator func NewFederator(db db.DB, federatingDB federatingdb.DB, transportController transport.Controller, config *config.Config, log *logrus.Logger, typeConverter typeutils.TypeConverter, mediaHandler media.Handler) Federator { - dereferencer := dereferencing.NewDereferencer(config, db, typeConverter, transportController, mediaHandler, log) clock := &Clock{} diff --git a/internal/federation/sqlite-test.db b/internal/federation/sqlite-test.db Binary files differ. diff --git a/internal/oauth/sqlite-test.db b/internal/oauth/sqlite-test.db Binary files differ. diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go @@ -49,7 +49,7 @@ func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmo return errors.New("note was not parseable as *gtsmodel.Status") } - status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus) + status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus, false, false) if err != nil { return err } diff --git a/internal/processing/search.go b/internal/processing/search.go @@ -130,7 +130,7 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u // we don't have it locally so dereference it if we're allowed to if resolve { - status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true) + status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true, false, false) if err == nil { if err := p.federator.DereferenceRemoteThread(ctx, authed.Account.Username, uri); err != nil { // try to deref the thread while we're here diff --git a/internal/processing/status/sqlite-test.db b/internal/processing/status/sqlite-test.db Binary files differ. diff --git a/internal/text/sqlite-test.db b/internal/text/sqlite-test.db Binary files differ. diff --git a/internal/timeline/sqlite-test.db b/internal/timeline/sqlite-test.db Binary files differ. diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go @@ -339,7 +339,6 @@ func (c *converter) ASStatusToStatus(ctx context.Context, statusable ap.Statusab } func (c *converter) ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error) { - idProp := followable.GetJSONLDId() if idProp == nil || !idProp.IsIRI() { return nil, errors.New("no id property set on follow, or was not an iri") diff --git a/internal/typeutils/sqlite-test.db b/internal/typeutils/sqlite-test.db Binary files differ.