relationship_follow_req.go (10618B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package bundb 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "time" 25 26 "github.com/superseriousbusiness/gotosocial/internal/db" 27 "github.com/superseriousbusiness/gotosocial/internal/gtscontext" 28 "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" 29 "github.com/superseriousbusiness/gotosocial/internal/log" 30 "github.com/uptrace/bun" 31 ) 32 33 func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) { 34 return r.getFollowRequest( 35 ctx, 36 "ID", 37 func(followReq *gtsmodel.FollowRequest) error { 38 return r.conn.NewSelect(). 39 Model(followReq). 40 Where("? = ?", bun.Ident("id"), id). 41 Scan(ctx) 42 }, 43 id, 44 ) 45 } 46 47 func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) { 48 return r.getFollowRequest( 49 ctx, 50 "URI", 51 func(followReq *gtsmodel.FollowRequest) error { 52 return r.conn.NewSelect(). 53 Model(followReq). 54 Where("? = ?", bun.Ident("uri"), uri). 55 Scan(ctx) 56 }, 57 uri, 58 ) 59 } 60 61 func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) { 62 return r.getFollowRequest( 63 ctx, 64 "AccountID.TargetAccountID", 65 func(followReq *gtsmodel.FollowRequest) error { 66 return r.conn.NewSelect(). 67 Model(followReq). 68 Where("? = ?", bun.Ident("account_id"), sourceAccountID). 69 Where("? = ?", bun.Ident("target_account_id"), targetAccountID). 70 Scan(ctx) 71 }, 72 sourceAccountID, 73 targetAccountID, 74 ) 75 } 76 77 func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) { 78 // Preallocate slice of expected length. 79 followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids)) 80 81 for _, id := range ids { 82 // Fetch follow request model for this ID. 83 followReq, err := r.GetFollowRequestByID(ctx, id) 84 if err != nil { 85 log.Errorf(ctx, "error getting follow request %q: %v", id, err) 86 continue 87 } 88 89 // Append to return slice. 90 followReqs = append(followReqs, followReq) 91 } 92 93 return followReqs, nil 94 } 95 96 func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { 97 followReq, err := r.GetFollowRequest( 98 gtscontext.SetBarebones(ctx), 99 sourceAccountID, 100 targetAccountID, 101 ) 102 if err != nil && !errors.Is(err, db.ErrNoEntries) { 103 return false, err 104 } 105 return (followReq != nil), nil 106 } 107 108 func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) { 109 // Fetch follow request from database cache with loader callback 110 followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) { 111 var followReq gtsmodel.FollowRequest 112 113 // Not cached! Perform database query 114 if err := dbQuery(&followReq); err != nil { 115 return nil, r.conn.ProcessError(err) 116 } 117 118 return &followReq, nil 119 }, keyParts...) 120 if err != nil { 121 // error already processed 122 return nil, err 123 } 124 125 if gtscontext.Barebones(ctx) { 126 // Only a barebones model was requested. 127 return followReq, nil 128 } 129 130 // Set the follow request source account 131 followReq.Account, err = r.state.DB.GetAccountByID( 132 gtscontext.SetBarebones(ctx), 133 followReq.AccountID, 134 ) 135 if err != nil { 136 return nil, fmt.Errorf("error getting follow request source account: %w", err) 137 } 138 139 // Set the follow request target account 140 followReq.TargetAccount, err = r.state.DB.GetAccountByID( 141 gtscontext.SetBarebones(ctx), 142 followReq.TargetAccountID, 143 ) 144 if err != nil { 145 return nil, fmt.Errorf("error getting follow request target account: %w", err) 146 } 147 148 return followReq, nil 149 } 150 151 func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { 152 return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { 153 _, err := r.conn.NewInsert().Model(follow).Exec(ctx) 154 return r.conn.ProcessError(err) 155 }) 156 } 157 158 func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest, columns ...string) error { 159 followRequest.UpdatedAt = time.Now() 160 if len(columns) > 0 { 161 // If we're updating by column, ensure "updated_at" is included. 162 columns = append(columns, "updated_at") 163 } 164 165 return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error { 166 if _, err := r.conn.NewUpdate(). 167 Model(followRequest). 168 Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). 169 Column(columns...). 170 Exec(ctx); err != nil { 171 return r.conn.ProcessError(err) 172 } 173 174 return nil 175 }) 176 } 177 178 func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { 179 // Get original follow request. 180 followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID) 181 if err != nil { 182 return nil, err 183 } 184 185 // Create a new follow to 'replace' 186 // the original follow request with. 187 follow := >smodel.Follow{ 188 ID: followReq.ID, 189 AccountID: sourceAccountID, 190 Account: followReq.Account, 191 TargetAccountID: targetAccountID, 192 TargetAccount: followReq.TargetAccount, 193 URI: followReq.URI, 194 ShowReblogs: followReq.ShowReblogs, 195 Notify: followReq.Notify, 196 } 197 198 if err := r.state.Caches.GTS.Follow().Store(follow, func() error { 199 // If the follow already exists, just 200 // replace the URI with the new one. 201 _, err := r.conn. 202 NewInsert(). 203 Model(follow). 204 On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). 205 Exec(ctx) 206 return r.conn.ProcessError(err) 207 }); err != nil { 208 return nil, err 209 } 210 211 // Invalidate follow request from cache lookups on return. 212 defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) 213 214 // Delete original follow request. 215 if _, err := r.conn. 216 NewDelete(). 217 Table("follow_requests"). 218 Where("? = ?", bun.Ident("id"), followReq.ID). 219 Exec(ctx); err != nil { 220 return nil, r.conn.ProcessError(err) 221 } 222 223 // Delete original follow request notification 224 if err := r.state.DB.DeleteNotifications(ctx, []string{ 225 string(gtsmodel.NotificationFollowRequest), 226 }, targetAccountID, sourceAccountID); err != nil { 227 return nil, err 228 } 229 230 return follow, nil 231 } 232 233 func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error { 234 defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) 235 236 // Load followreq into cache before attempting a delete, 237 // as we need it cached in order to trigger the invalidate 238 // callback. This in turn invalidates others. 239 _, err := r.GetFollowRequest(gtscontext.SetBarebones(ctx), 240 sourceAccountID, 241 targetAccountID, 242 ) 243 if err != nil { 244 return err 245 } 246 247 // Attempt to delete follow request. 248 if _, err = r.conn.NewDelete(). 249 Table("follow_requests"). 250 Where("? = ? AND ? = ?", 251 bun.Ident("account_id"), 252 sourceAccountID, 253 bun.Ident("target_account_id"), 254 targetAccountID, 255 ). 256 Exec(ctx); err != nil { 257 return r.conn.ProcessError(err) 258 } 259 260 // Delete original follow request notification 261 return r.state.DB.DeleteNotifications(ctx, []string{ 262 string(gtsmodel.NotificationFollowRequest), 263 }, targetAccountID, sourceAccountID) 264 } 265 266 func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { 267 defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) 268 269 // Load followreq into cache before attempting a delete, 270 // as we need it cached in order to trigger the invalidate 271 // callback. This in turn invalidates others. 272 _, err := r.GetFollowRequestByID(gtscontext.SetBarebones(ctx), id) 273 if err != nil { 274 if errors.Is(err, db.ErrNoEntries) { 275 // not an issue. 276 err = nil 277 } 278 return err 279 } 280 281 // Finally delete followreq from DB. 282 _, err = r.conn.NewDelete(). 283 Table("follow_requests"). 284 Where("? = ?", bun.Ident("id"), id). 285 Exec(ctx) 286 return r.conn.ProcessError(err) 287 } 288 289 func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { 290 defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) 291 292 // Load followreq into cache before attempting a delete, 293 // as we need it cached in order to trigger the invalidate 294 // callback. This in turn invalidates others. 295 _, err := r.GetFollowRequestByURI(gtscontext.SetBarebones(ctx), uri) 296 if err != nil { 297 if errors.Is(err, db.ErrNoEntries) { 298 // not an issue. 299 err = nil 300 } 301 return err 302 } 303 304 // Finally delete followreq from DB. 305 _, err = r.conn.NewDelete(). 306 Table("follow_requests"). 307 Where("? = ?", bun.Ident("uri"), uri). 308 Exec(ctx) 309 return r.conn.ProcessError(err) 310 } 311 312 func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error { 313 var followReqIDs []string 314 315 // Get full list of IDs. 316 if _, err := r.conn. 317 NewSelect(). 318 Column("id"). 319 Table("follow_requestss"). 320 WhereOr("? = ? OR ? = ?", 321 bun.Ident("account_id"), 322 accountID, 323 bun.Ident("target_account_id"), 324 accountID, 325 ). 326 Exec(ctx, &followReqIDs); err != nil { 327 return r.conn.ProcessError(err) 328 } 329 330 defer func() { 331 // Invalidate all IDs on return. 332 for _, id := range followReqIDs { 333 r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) 334 } 335 }() 336 337 // Load all followreqs into cache, this *really* isn't 338 // great but it is the only way we can ensure we invalidate 339 // all related caches correctly (e.g. visibility). 340 for _, id := range followReqIDs { 341 _, err := r.GetFollowRequestByID(ctx, id) 342 if err != nil && !errors.Is(err, db.ErrNoEntries) { 343 return err 344 } 345 } 346 347 // Finally delete all from DB. 348 _, err := r.conn.NewDelete(). 349 Table("follow_requests"). 350 Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)). 351 Exec(ctx) 352 return r.conn.ProcessError(err) 353 }