relationship_follow.go (8616B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package bundb 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "time" 25 26 "github.com/superseriousbusiness/gotosocial/internal/db" 27 "github.com/superseriousbusiness/gotosocial/internal/gtscontext" 28 "github.com/superseriousbusiness/gotosocial/internal/gtserror" 29 "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" 30 "github.com/superseriousbusiness/gotosocial/internal/log" 31 "github.com/uptrace/bun" 32 ) 33 34 func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) { 35 return r.getFollow( 36 ctx, 37 "ID", 38 func(follow *gtsmodel.Follow) error { 39 return r.conn.NewSelect(). 40 Model(follow). 41 Where("? = ?", bun.Ident("id"), id). 42 Scan(ctx) 43 }, 44 id, 45 ) 46 } 47 48 func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) { 49 return r.getFollow( 50 ctx, 51 "URI", 52 func(follow *gtsmodel.Follow) error { 53 return r.conn.NewSelect(). 54 Model(follow). 55 Where("? = ?", bun.Ident("uri"), uri). 56 Scan(ctx) 57 }, 58 uri, 59 ) 60 } 61 62 func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { 63 return r.getFollow( 64 ctx, 65 "AccountID.TargetAccountID", 66 func(follow *gtsmodel.Follow) error { 67 return r.conn.NewSelect(). 68 Model(follow). 69 Where("? = ?", bun.Ident("account_id"), sourceAccountID). 70 Where("? = ?", bun.Ident("target_account_id"), targetAccountID). 71 Scan(ctx) 72 }, 73 sourceAccountID, 74 targetAccountID, 75 ) 76 } 77 78 func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) { 79 // Preallocate slice of expected length. 80 follows := make([]*gtsmodel.Follow, 0, len(ids)) 81 82 for _, id := range ids { 83 // Fetch follow model for this ID. 84 follow, err := r.GetFollowByID(ctx, id) 85 if err != nil { 86 log.Errorf(ctx, "error getting follow %q: %v", id, err) 87 continue 88 } 89 90 // Append to return slice. 91 follows = append(follows, follow) 92 } 93 94 return follows, nil 95 } 96 97 func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { 98 follow, err := r.GetFollow( 99 gtscontext.SetBarebones(ctx), 100 sourceAccountID, 101 targetAccountID, 102 ) 103 if err != nil && !errors.Is(err, db.ErrNoEntries) { 104 return false, err 105 } 106 return (follow != nil), nil 107 } 108 109 func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) { 110 // make sure account 1 follows account 2 111 f1, err := r.IsFollowing(ctx, 112 accountID1, 113 accountID2, 114 ) 115 if !f1 /* f1 = false when err != nil */ { 116 return false, err 117 } 118 119 // make sure account 2 follows account 1 120 f2, err := r.IsFollowing(ctx, 121 accountID2, 122 accountID1, 123 ) 124 if !f2 /* f2 = false when err != nil */ { 125 return false, err 126 } 127 128 return true, nil 129 } 130 131 func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) { 132 // Fetch follow from database cache with loader callback 133 follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) { 134 var follow gtsmodel.Follow 135 136 // Not cached! Perform database query 137 if err := dbQuery(&follow); err != nil { 138 return nil, r.conn.ProcessError(err) 139 } 140 141 return &follow, nil 142 }, keyParts...) 143 if err != nil { 144 // error already processed 145 return nil, err 146 } 147 148 if gtscontext.Barebones(ctx) { 149 // Only a barebones model was requested. 150 return follow, nil 151 } 152 153 if err := r.state.DB.PopulateFollow(ctx, follow); err != nil { 154 return nil, err 155 } 156 157 return follow, nil 158 } 159 160 func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error { 161 var ( 162 err error 163 errs = make(gtserror.MultiError, 0, 2) 164 ) 165 166 if follow.Account == nil { 167 // Follow account is not set, fetch from the database. 168 follow.Account, err = r.state.DB.GetAccountByID( 169 gtscontext.SetBarebones(ctx), 170 follow.AccountID, 171 ) 172 if err != nil { 173 errs.Append(fmt.Errorf("error populating follow account: %w", err)) 174 } 175 } 176 177 if follow.TargetAccount == nil { 178 // Follow target account is not set, fetch from the database. 179 follow.TargetAccount, err = r.state.DB.GetAccountByID( 180 gtscontext.SetBarebones(ctx), 181 follow.TargetAccountID, 182 ) 183 if err != nil { 184 errs.Append(fmt.Errorf("error populating follow target account: %w", err)) 185 } 186 } 187 188 return errs.Combine() 189 } 190 191 func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { 192 return r.state.Caches.GTS.Follow().Store(follow, func() error { 193 _, err := r.conn.NewInsert().Model(follow).Exec(ctx) 194 return r.conn.ProcessError(err) 195 }) 196 } 197 198 func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Follow, columns ...string) error { 199 follow.UpdatedAt = time.Now() 200 if len(columns) > 0 { 201 // If we're updating by column, ensure "updated_at" is included. 202 columns = append(columns, "updated_at") 203 } 204 205 return r.state.Caches.GTS.Follow().Store(follow, func() error { 206 if _, err := r.conn.NewUpdate(). 207 Model(follow). 208 Where("? = ?", bun.Ident("follow.id"), follow.ID). 209 Column(columns...). 210 Exec(ctx); err != nil { 211 return r.conn.ProcessError(err) 212 } 213 214 return nil 215 }) 216 } 217 218 func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { 219 // Delete the follow itself using the given ID. 220 if _, err := r.conn.NewDelete(). 221 Table("follows"). 222 Where("? = ?", bun.Ident("id"), id). 223 Exec(ctx); err != nil { 224 return r.conn.ProcessError(err) 225 } 226 227 // Delete every list entry that used this followID. 228 if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil { 229 return fmt.Errorf("deleteFollow: error deleting list entries: %w", err) 230 } 231 232 return nil 233 } 234 235 func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error { 236 defer r.state.Caches.GTS.Follow().Invalidate("ID", id) 237 238 // Load follow into cache before attempting a delete, 239 // as we need it cached in order to trigger the invalidate 240 // callback. This in turn invalidates others. 241 follow, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id) 242 if err != nil { 243 if errors.Is(err, db.ErrNoEntries) { 244 // Already gone. 245 return nil 246 } 247 return err 248 } 249 250 // Finally delete follow from DB. 251 return r.deleteFollow(ctx, follow.ID) 252 } 253 254 func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error { 255 defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) 256 257 // Load follow into cache before attempting a delete, 258 // as we need it cached in order to trigger the invalidate 259 // callback. This in turn invalidates others. 260 follow, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri) 261 if err != nil { 262 if errors.Is(err, db.ErrNoEntries) { 263 // Already gone. 264 return nil 265 } 266 return err 267 } 268 269 // Finally delete follow from DB. 270 return r.deleteFollow(ctx, follow.ID) 271 } 272 273 func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error { 274 var followIDs []string 275 276 // Get full list of IDs. 277 if _, err := r.conn. 278 NewSelect(). 279 Column("id"). 280 Table("follows"). 281 WhereOr("? = ? OR ? = ?", 282 bun.Ident("account_id"), 283 accountID, 284 bun.Ident("target_account_id"), 285 accountID, 286 ). 287 Exec(ctx, &followIDs); err != nil { 288 return r.conn.ProcessError(err) 289 } 290 291 defer func() { 292 // Invalidate all IDs on return. 293 for _, id := range followIDs { 294 r.state.Caches.GTS.Follow().Invalidate("ID", id) 295 } 296 }() 297 298 // Load all follows into cache, this *really* isn't great 299 // but it is the only way we can ensure we invalidate all 300 // related caches correctly (e.g. visibility). 301 for _, id := range followIDs { 302 follow, err := r.GetFollowByID(ctx, id) 303 if err != nil && !errors.Is(err, db.ErrNoEntries) { 304 return err 305 } 306 307 // Delete each follow from DB. 308 if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { 309 return err 310 } 311 } 312 313 return nil 314 }