list.go (12172B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package bundb 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "time" 25 26 "github.com/superseriousbusiness/gotosocial/internal/db" 27 "github.com/superseriousbusiness/gotosocial/internal/gtscontext" 28 "github.com/superseriousbusiness/gotosocial/internal/gtserror" 29 "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" 30 "github.com/superseriousbusiness/gotosocial/internal/log" 31 "github.com/superseriousbusiness/gotosocial/internal/state" 32 "github.com/uptrace/bun" 33 ) 34 35 type listDB struct { 36 conn *DBConn 37 state *state.State 38 } 39 40 /* 41 LIST FUNCTIONS 42 */ 43 44 func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { 45 list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { 46 var list gtsmodel.List 47 48 // Not cached! Perform database query. 49 if err := dbQuery(&list); err != nil { 50 return nil, l.conn.ProcessError(err) 51 } 52 53 return &list, nil 54 }, keyParts...) 55 if err != nil { 56 return nil, err // already processed 57 } 58 59 if gtscontext.Barebones(ctx) { 60 // Only a barebones model was requested. 61 return list, nil 62 } 63 64 if err := l.state.DB.PopulateList(ctx, list); err != nil { 65 return nil, err 66 } 67 68 return list, nil 69 } 70 71 func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) { 72 return l.getList( 73 ctx, 74 "ID", 75 func(list *gtsmodel.List) error { 76 return l.conn.NewSelect(). 77 Model(list). 78 Where("? = ?", bun.Ident("list.id"), id). 79 Scan(ctx) 80 }, 81 id, 82 ) 83 } 84 85 func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { 86 // Fetch IDs of all lists owned by this account. 87 var listIDs []string 88 if err := l.conn. 89 NewSelect(). 90 TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")). 91 Column("list.id"). 92 Where("? = ?", bun.Ident("list.account_id"), accountID). 93 Order("list.id DESC"). 94 Scan(ctx, &listIDs); err != nil { 95 return nil, l.conn.ProcessError(err) 96 } 97 98 if len(listIDs) == 0 { 99 return nil, nil 100 } 101 102 // Select each list using its ID to ensure cache used. 103 lists := make([]*gtsmodel.List, 0, len(listIDs)) 104 for _, id := range listIDs { 105 list, err := l.state.DB.GetListByID(ctx, id) 106 if err != nil { 107 log.Errorf(ctx, "error fetching list %q: %v", id, err) 108 continue 109 } 110 111 // Append list. 112 lists = append(lists, list) 113 } 114 115 return lists, nil 116 } 117 118 func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { 119 var ( 120 err error 121 errs = make(gtserror.MultiError, 0, 2) 122 ) 123 124 if list.Account == nil { 125 // List account is not set, fetch from the database. 126 list.Account, err = l.state.DB.GetAccountByID( 127 gtscontext.SetBarebones(ctx), 128 list.AccountID, 129 ) 130 if err != nil { 131 errs.Append(fmt.Errorf("error populating list account: %w", err)) 132 } 133 } 134 135 if list.ListEntries == nil { 136 // List entries are not set, fetch from the database. 137 list.ListEntries, err = l.state.DB.GetListEntries( 138 gtscontext.SetBarebones(ctx), 139 list.ID, 140 "", "", "", 0, 141 ) 142 if err != nil { 143 errs.Append(fmt.Errorf("error populating list entries: %w", err)) 144 } 145 } 146 147 return errs.Combine() 148 } 149 150 func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { 151 return l.state.Caches.GTS.List().Store(list, func() error { 152 _, err := l.conn.NewInsert().Model(list).Exec(ctx) 153 return l.conn.ProcessError(err) 154 }) 155 } 156 157 func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error { 158 list.UpdatedAt = time.Now() 159 if len(columns) > 0 { 160 // If we're updating by column, ensure "updated_at" is included. 161 columns = append(columns, "updated_at") 162 } 163 164 return l.state.Caches.GTS.List().Store(list, func() error { 165 if _, err := l.conn.NewUpdate(). 166 Model(list). 167 Where("? = ?", bun.Ident("list.id"), list.ID). 168 Column(columns...). 169 Exec(ctx); err != nil { 170 return l.conn.ProcessError(err) 171 } 172 173 return nil 174 }) 175 } 176 177 func (l *listDB) DeleteListByID(ctx context.Context, id string) error { 178 defer l.state.Caches.GTS.List().Invalidate("ID", id) 179 180 // Select all entries that belong to this list. 181 listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0) 182 if err != nil { 183 return fmt.Errorf("error selecting entries from list %q: %w", id, err) 184 } 185 186 // Delete each list entry. This will 187 // invalidate the list timeline too. 188 for _, listEntry := range listEntries { 189 err := l.state.DB.DeleteListEntry(ctx, listEntry.ID) 190 if err != nil && !errors.Is(err, db.ErrNoEntries) { 191 return err 192 } 193 } 194 195 // Finally delete list itself from DB. 196 _, err = l.conn.NewDelete(). 197 Table("lists"). 198 Where("? = ?", bun.Ident("id"), id). 199 Exec(ctx) 200 return l.conn.ProcessError(err) 201 } 202 203 /* 204 LIST ENTRY functions 205 */ 206 207 func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { 208 listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { 209 var listEntry gtsmodel.ListEntry 210 211 // Not cached! Perform database query. 212 if err := dbQuery(&listEntry); err != nil { 213 return nil, l.conn.ProcessError(err) 214 } 215 216 return &listEntry, nil 217 }, keyParts...) 218 if err != nil { 219 return nil, err // already processed 220 } 221 222 if gtscontext.Barebones(ctx) { 223 // Only a barebones model was requested. 224 return listEntry, nil 225 } 226 227 // Further populate the list entry fields where applicable. 228 if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil { 229 return nil, err 230 } 231 232 return listEntry, nil 233 } 234 235 func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { 236 return l.getListEntry( 237 ctx, 238 "ID", 239 func(listEntry *gtsmodel.ListEntry) error { 240 return l.conn.NewSelect(). 241 Model(listEntry). 242 Where("? = ?", bun.Ident("list_entry.id"), id). 243 Scan(ctx) 244 }, 245 id, 246 ) 247 } 248 249 func (l *listDB) GetListEntries(ctx context.Context, 250 listID string, 251 maxID string, 252 sinceID string, 253 minID string, 254 limit int, 255 ) ([]*gtsmodel.ListEntry, error) { 256 // Ensure reasonable 257 if limit < 0 { 258 limit = 0 259 } 260 261 // Make educated guess for slice size 262 var ( 263 entryIDs = make([]string, 0, limit) 264 frontToBack = true 265 ) 266 267 q := l.conn. 268 NewSelect(). 269 TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). 270 // Select only IDs from table 271 Column("entry.id"). 272 // Select only entries belonging to listID. 273 Where("? = ?", bun.Ident("entry.list_id"), listID) 274 275 if maxID != "" { 276 // return only entries LOWER (ie., older) than maxID 277 q = q.Where("? < ?", bun.Ident("entry.id"), maxID) 278 } 279 280 if sinceID != "" { 281 // return only entries HIGHER (ie., newer) than sinceID 282 q = q.Where("? > ?", bun.Ident("entry.id"), sinceID) 283 } 284 285 if minID != "" { 286 // return only entries HIGHER (ie., newer) than minID 287 q = q.Where("? > ?", bun.Ident("entry.id"), minID) 288 289 // page up 290 frontToBack = false 291 } 292 293 if limit > 0 { 294 // limit amount of entries returned 295 q = q.Limit(limit) 296 } 297 298 if frontToBack { 299 // Page down. 300 q = q.Order("entry.id DESC") 301 } else { 302 // Page up. 303 q = q.Order("entry.id ASC") 304 } 305 306 if err := q.Scan(ctx, &entryIDs); err != nil { 307 return nil, l.conn.ProcessError(err) 308 } 309 310 if len(entryIDs) == 0 { 311 return nil, nil 312 } 313 314 // If we're paging up, we still want entries 315 // to be sorted by ID desc, so reverse ids slice. 316 // https://zchee.github.io/golang-wiki/SliceTricks/#reversing 317 if !frontToBack { 318 for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 { 319 entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l] 320 } 321 } 322 323 // Select each list entry using its ID to ensure cache used. 324 listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) 325 for _, id := range entryIDs { 326 listEntry, err := l.state.DB.GetListEntryByID(ctx, id) 327 if err != nil { 328 log.Errorf(ctx, "error fetching list entry %q: %v", id, err) 329 continue 330 } 331 332 // Append list entries. 333 listEntries = append(listEntries, listEntry) 334 } 335 336 return listEntries, nil 337 } 338 339 func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { 340 entryIDs := []string{} 341 342 if err := l.conn. 343 NewSelect(). 344 TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). 345 // Select only IDs from table 346 Column("entry.id"). 347 // Select only entries belonging with given followID. 348 Where("? = ?", bun.Ident("entry.follow_id"), followID). 349 Scan(ctx, &entryIDs); err != nil { 350 return nil, l.conn.ProcessError(err) 351 } 352 353 if len(entryIDs) == 0 { 354 return nil, nil 355 } 356 357 // Select each list entry using its ID to ensure cache used. 358 listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) 359 for _, id := range entryIDs { 360 listEntry, err := l.state.DB.GetListEntryByID(ctx, id) 361 if err != nil { 362 log.Errorf(ctx, "error fetching list entry %q: %v", id, err) 363 continue 364 } 365 366 // Append list entries. 367 listEntries = append(listEntries, listEntry) 368 } 369 370 return listEntries, nil 371 } 372 373 func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { 374 var err error 375 376 if listEntry.Follow == nil { 377 // ListEntry follow is not set, fetch from the database. 378 listEntry.Follow, err = l.state.DB.GetFollowByID( 379 gtscontext.SetBarebones(ctx), 380 listEntry.FollowID, 381 ) 382 if err != nil { 383 return fmt.Errorf("error populating listEntry follow: %w", err) 384 } 385 } 386 387 return nil 388 } 389 390 func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error { 391 return l.conn.RunInTx(ctx, func(tx bun.Tx) error { 392 for _, listEntry := range listEntries { 393 if _, err := tx. 394 NewInsert(). 395 Model(listEntry). 396 Exec(ctx); err != nil { 397 return err 398 } 399 400 // Invalidate the timeline for the list this entry belongs to. 401 if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { 402 log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err) 403 } 404 } 405 406 return nil 407 }) 408 } 409 410 func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { 411 defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id) 412 413 // Load list entry into cache before attempting a delete, 414 // as we need the followID from it in order to trigger 415 // timeline invalidation. 416 listEntry, err := l.GetListEntryByID( 417 // Don't populate the entry; 418 // we only want the list ID. 419 gtscontext.SetBarebones(ctx), 420 id, 421 ) 422 if err != nil { 423 if errors.Is(err, db.ErrNoEntries) { 424 // Already gone. 425 return nil 426 } 427 return err 428 } 429 430 defer func() { 431 // Invalidate the timeline for the list this entry belongs to. 432 if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { 433 log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err) 434 } 435 }() 436 437 if _, err := l.conn.NewDelete(). 438 Table("list_entries"). 439 Where("? = ?", bun.Ident("id"), listEntry.ID). 440 Exec(ctx); err != nil { 441 return l.conn.ProcessError(err) 442 } 443 444 return nil 445 } 446 447 func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error { 448 // Fetch IDs of all entries that pertain to this follow. 449 var listEntryIDs []string 450 if err := l.conn. 451 NewSelect(). 452 TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")). 453 Column("list_entry.id"). 454 Where("? = ?", bun.Ident("list_entry.follow_id"), followID). 455 Order("list_entry.id DESC"). 456 Scan(ctx, &listEntryIDs); err != nil { 457 return l.conn.ProcessError(err) 458 } 459 460 for _, id := range listEntryIDs { 461 if err := l.DeleteListEntry(ctx, id); err != nil { 462 return err 463 } 464 } 465 466 return nil 467 }