gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

list.go (12172B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"fmt"
     24 	"time"
     25 
     26 	"github.com/superseriousbusiness/gotosocial/internal/db"
     27 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
     28 	"github.com/superseriousbusiness/gotosocial/internal/gtserror"
     29 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     30 	"github.com/superseriousbusiness/gotosocial/internal/log"
     31 	"github.com/superseriousbusiness/gotosocial/internal/state"
     32 	"github.com/uptrace/bun"
     33 )
     34 
// listDB implements the list and list-entry database functions in this
// file, using a bun database connection plus the shared application state
// (its caches, timelines, and the wider DB interface).
type listDB struct {
	// conn is the database connection used to build and run queries.
	conn  *DBConn
	// state gives access to the GTS caches, list timelines, and other DB functions.
	state *state.State
}
     39 
     40 /*
     41 	LIST FUNCTIONS
     42 */
     43 
     44 func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
     45 	list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {
     46 		var list gtsmodel.List
     47 
     48 		// Not cached! Perform database query.
     49 		if err := dbQuery(&list); err != nil {
     50 			return nil, l.conn.ProcessError(err)
     51 		}
     52 
     53 		return &list, nil
     54 	}, keyParts...)
     55 	if err != nil {
     56 		return nil, err // already processed
     57 	}
     58 
     59 	if gtscontext.Barebones(ctx) {
     60 		// Only a barebones model was requested.
     61 		return list, nil
     62 	}
     63 
     64 	if err := l.state.DB.PopulateList(ctx, list); err != nil {
     65 		return nil, err
     66 	}
     67 
     68 	return list, nil
     69 }
     70 
     71 func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) {
     72 	return l.getList(
     73 		ctx,
     74 		"ID",
     75 		func(list *gtsmodel.List) error {
     76 			return l.conn.NewSelect().
     77 				Model(list).
     78 				Where("? = ?", bun.Ident("list.id"), id).
     79 				Scan(ctx)
     80 		},
     81 		id,
     82 	)
     83 }
     84 
     85 func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) {
     86 	// Fetch IDs of all lists owned by this account.
     87 	var listIDs []string
     88 	if err := l.conn.
     89 		NewSelect().
     90 		TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")).
     91 		Column("list.id").
     92 		Where("? = ?", bun.Ident("list.account_id"), accountID).
     93 		Order("list.id DESC").
     94 		Scan(ctx, &listIDs); err != nil {
     95 		return nil, l.conn.ProcessError(err)
     96 	}
     97 
     98 	if len(listIDs) == 0 {
     99 		return nil, nil
    100 	}
    101 
    102 	// Select each list using its ID to ensure cache used.
    103 	lists := make([]*gtsmodel.List, 0, len(listIDs))
    104 	for _, id := range listIDs {
    105 		list, err := l.state.DB.GetListByID(ctx, id)
    106 		if err != nil {
    107 			log.Errorf(ctx, "error fetching list %q: %v", id, err)
    108 			continue
    109 		}
    110 
    111 		// Append list.
    112 		lists = append(lists, list)
    113 	}
    114 
    115 	return lists, nil
    116 }
    117 
    118 func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
    119 	var (
    120 		err  error
    121 		errs = make(gtserror.MultiError, 0, 2)
    122 	)
    123 
    124 	if list.Account == nil {
    125 		// List account is not set, fetch from the database.
    126 		list.Account, err = l.state.DB.GetAccountByID(
    127 			gtscontext.SetBarebones(ctx),
    128 			list.AccountID,
    129 		)
    130 		if err != nil {
    131 			errs.Append(fmt.Errorf("error populating list account: %w", err))
    132 		}
    133 	}
    134 
    135 	if list.ListEntries == nil {
    136 		// List entries are not set, fetch from the database.
    137 		list.ListEntries, err = l.state.DB.GetListEntries(
    138 			gtscontext.SetBarebones(ctx),
    139 			list.ID,
    140 			"", "", "", 0,
    141 		)
    142 		if err != nil {
    143 			errs.Append(fmt.Errorf("error populating list entries: %w", err))
    144 		}
    145 	}
    146 
    147 	return errs.Combine()
    148 }
    149 
    150 func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
    151 	return l.state.Caches.GTS.List().Store(list, func() error {
    152 		_, err := l.conn.NewInsert().Model(list).Exec(ctx)
    153 		return l.conn.ProcessError(err)
    154 	})
    155 }
    156 
    157 func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error {
    158 	list.UpdatedAt = time.Now()
    159 	if len(columns) > 0 {
    160 		// If we're updating by column, ensure "updated_at" is included.
    161 		columns = append(columns, "updated_at")
    162 	}
    163 
    164 	return l.state.Caches.GTS.List().Store(list, func() error {
    165 		if _, err := l.conn.NewUpdate().
    166 			Model(list).
    167 			Where("? = ?", bun.Ident("list.id"), list.ID).
    168 			Column(columns...).
    169 			Exec(ctx); err != nil {
    170 			return l.conn.ProcessError(err)
    171 		}
    172 
    173 		return nil
    174 	})
    175 }
    176 
    177 func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
    178 	defer l.state.Caches.GTS.List().Invalidate("ID", id)
    179 
    180 	// Select all entries that belong to this list.
    181 	listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0)
    182 	if err != nil {
    183 		return fmt.Errorf("error selecting entries from list %q: %w", id, err)
    184 	}
    185 
    186 	// Delete each list entry. This will
    187 	// invalidate the list timeline too.
    188 	for _, listEntry := range listEntries {
    189 		err := l.state.DB.DeleteListEntry(ctx, listEntry.ID)
    190 		if err != nil && !errors.Is(err, db.ErrNoEntries) {
    191 			return err
    192 		}
    193 	}
    194 
    195 	// Finally delete list itself from DB.
    196 	_, err = l.conn.NewDelete().
    197 		Table("lists").
    198 		Where("? = ?", bun.Ident("id"), id).
    199 		Exec(ctx)
    200 	return l.conn.ProcessError(err)
    201 }
    202 
    203 /*
    204 	LIST ENTRY functions
    205 */
    206 
    207 func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
    208 	listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {
    209 		var listEntry gtsmodel.ListEntry
    210 
    211 		// Not cached! Perform database query.
    212 		if err := dbQuery(&listEntry); err != nil {
    213 			return nil, l.conn.ProcessError(err)
    214 		}
    215 
    216 		return &listEntry, nil
    217 	}, keyParts...)
    218 	if err != nil {
    219 		return nil, err // already processed
    220 	}
    221 
    222 	if gtscontext.Barebones(ctx) {
    223 		// Only a barebones model was requested.
    224 		return listEntry, nil
    225 	}
    226 
    227 	// Further populate the list entry fields where applicable.
    228 	if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
    229 		return nil, err
    230 	}
    231 
    232 	return listEntry, nil
    233 }
    234 
    235 func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) {
    236 	return l.getListEntry(
    237 		ctx,
    238 		"ID",
    239 		func(listEntry *gtsmodel.ListEntry) error {
    240 			return l.conn.NewSelect().
    241 				Model(listEntry).
    242 				Where("? = ?", bun.Ident("list_entry.id"), id).
    243 				Scan(ctx)
    244 		},
    245 		id,
    246 	)
    247 }
    248 
    249 func (l *listDB) GetListEntries(ctx context.Context,
    250 	listID string,
    251 	maxID string,
    252 	sinceID string,
    253 	minID string,
    254 	limit int,
    255 ) ([]*gtsmodel.ListEntry, error) {
    256 	// Ensure reasonable
    257 	if limit < 0 {
    258 		limit = 0
    259 	}
    260 
    261 	// Make educated guess for slice size
    262 	var (
    263 		entryIDs    = make([]string, 0, limit)
    264 		frontToBack = true
    265 	)
    266 
    267 	q := l.conn.
    268 		NewSelect().
    269 		TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
    270 		// Select only IDs from table
    271 		Column("entry.id").
    272 		// Select only entries belonging to listID.
    273 		Where("? = ?", bun.Ident("entry.list_id"), listID)
    274 
    275 	if maxID != "" {
    276 		// return only entries LOWER (ie., older) than maxID
    277 		q = q.Where("? < ?", bun.Ident("entry.id"), maxID)
    278 	}
    279 
    280 	if sinceID != "" {
    281 		// return only entries HIGHER (ie., newer) than sinceID
    282 		q = q.Where("? > ?", bun.Ident("entry.id"), sinceID)
    283 	}
    284 
    285 	if minID != "" {
    286 		// return only entries HIGHER (ie., newer) than minID
    287 		q = q.Where("? > ?", bun.Ident("entry.id"), minID)
    288 
    289 		// page up
    290 		frontToBack = false
    291 	}
    292 
    293 	if limit > 0 {
    294 		// limit amount of entries returned
    295 		q = q.Limit(limit)
    296 	}
    297 
    298 	if frontToBack {
    299 		// Page down.
    300 		q = q.Order("entry.id DESC")
    301 	} else {
    302 		// Page up.
    303 		q = q.Order("entry.id ASC")
    304 	}
    305 
    306 	if err := q.Scan(ctx, &entryIDs); err != nil {
    307 		return nil, l.conn.ProcessError(err)
    308 	}
    309 
    310 	if len(entryIDs) == 0 {
    311 		return nil, nil
    312 	}
    313 
    314 	// If we're paging up, we still want entries
    315 	// to be sorted by ID desc, so reverse ids slice.
    316 	// https://zchee.github.io/golang-wiki/SliceTricks/#reversing
    317 	if !frontToBack {
    318 		for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 {
    319 			entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l]
    320 		}
    321 	}
    322 
    323 	// Select each list entry using its ID to ensure cache used.
    324 	listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
    325 	for _, id := range entryIDs {
    326 		listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
    327 		if err != nil {
    328 			log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
    329 			continue
    330 		}
    331 
    332 		// Append list entries.
    333 		listEntries = append(listEntries, listEntry)
    334 	}
    335 
    336 	return listEntries, nil
    337 }
    338 
    339 func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
    340 	entryIDs := []string{}
    341 
    342 	if err := l.conn.
    343 		NewSelect().
    344 		TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
    345 		// Select only IDs from table
    346 		Column("entry.id").
    347 		// Select only entries belonging with given followID.
    348 		Where("? = ?", bun.Ident("entry.follow_id"), followID).
    349 		Scan(ctx, &entryIDs); err != nil {
    350 		return nil, l.conn.ProcessError(err)
    351 	}
    352 
    353 	if len(entryIDs) == 0 {
    354 		return nil, nil
    355 	}
    356 
    357 	// Select each list entry using its ID to ensure cache used.
    358 	listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
    359 	for _, id := range entryIDs {
    360 		listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
    361 		if err != nil {
    362 			log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
    363 			continue
    364 		}
    365 
    366 		// Append list entries.
    367 		listEntries = append(listEntries, listEntry)
    368 	}
    369 
    370 	return listEntries, nil
    371 }
    372 
    373 func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
    374 	var err error
    375 
    376 	if listEntry.Follow == nil {
    377 		// ListEntry follow is not set, fetch from the database.
    378 		listEntry.Follow, err = l.state.DB.GetFollowByID(
    379 			gtscontext.SetBarebones(ctx),
    380 			listEntry.FollowID,
    381 		)
    382 		if err != nil {
    383 			return fmt.Errorf("error populating listEntry follow: %w", err)
    384 		}
    385 	}
    386 
    387 	return nil
    388 }
    389 
    390 func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error {
    391 	return l.conn.RunInTx(ctx, func(tx bun.Tx) error {
    392 		for _, listEntry := range listEntries {
    393 			if _, err := tx.
    394 				NewInsert().
    395 				Model(listEntry).
    396 				Exec(ctx); err != nil {
    397 				return err
    398 			}
    399 
    400 			// Invalidate the timeline for the list this entry belongs to.
    401 			if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
    402 				log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err)
    403 			}
    404 		}
    405 
    406 		return nil
    407 	})
    408 }
    409 
    410 func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
    411 	defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id)
    412 
    413 	// Load list entry into cache before attempting a delete,
    414 	// as we need the followID from it in order to trigger
    415 	// timeline invalidation.
    416 	listEntry, err := l.GetListEntryByID(
    417 		// Don't populate the entry;
    418 		// we only want the list ID.
    419 		gtscontext.SetBarebones(ctx),
    420 		id,
    421 	)
    422 	if err != nil {
    423 		if errors.Is(err, db.ErrNoEntries) {
    424 			// Already gone.
    425 			return nil
    426 		}
    427 		return err
    428 	}
    429 
    430 	defer func() {
    431 		// Invalidate the timeline for the list this entry belongs to.
    432 		if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
    433 			log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err)
    434 		}
    435 	}()
    436 
    437 	if _, err := l.conn.NewDelete().
    438 		Table("list_entries").
    439 		Where("? = ?", bun.Ident("id"), listEntry.ID).
    440 		Exec(ctx); err != nil {
    441 		return l.conn.ProcessError(err)
    442 	}
    443 
    444 	return nil
    445 }
    446 
    447 func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error {
    448 	// Fetch IDs of all entries that pertain to this follow.
    449 	var listEntryIDs []string
    450 	if err := l.conn.
    451 		NewSelect().
    452 		TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")).
    453 		Column("list_entry.id").
    454 		Where("? = ?", bun.Ident("list_entry.follow_id"), followID).
    455 		Order("list_entry.id DESC").
    456 		Scan(ctx, &listEntryIDs); err != nil {
    457 		return l.conn.ProcessError(err)
    458 	}
    459 
    460 	for _, id := range listEntryIDs {
    461 		if err := l.DeleteListEntry(ctx, id); err != nil {
    462 			return err
    463 		}
    464 	}
    465 
    466 	return nil
    467 }