gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

relationship_follow.go (8616B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"fmt"
     24 	"time"
     25 
     26 	"github.com/superseriousbusiness/gotosocial/internal/db"
     27 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
     28 	"github.com/superseriousbusiness/gotosocial/internal/gtserror"
     29 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     30 	"github.com/superseriousbusiness/gotosocial/internal/log"
     31 	"github.com/uptrace/bun"
     32 )
     33 
     34 func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) {
     35 	return r.getFollow(
     36 		ctx,
     37 		"ID",
     38 		func(follow *gtsmodel.Follow) error {
     39 			return r.conn.NewSelect().
     40 				Model(follow).
     41 				Where("? = ?", bun.Ident("id"), id).
     42 				Scan(ctx)
     43 		},
     44 		id,
     45 	)
     46 }
     47 
     48 func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) {
     49 	return r.getFollow(
     50 		ctx,
     51 		"URI",
     52 		func(follow *gtsmodel.Follow) error {
     53 			return r.conn.NewSelect().
     54 				Model(follow).
     55 				Where("? = ?", bun.Ident("uri"), uri).
     56 				Scan(ctx)
     57 		},
     58 		uri,
     59 	)
     60 }
     61 
     62 func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
     63 	return r.getFollow(
     64 		ctx,
     65 		"AccountID.TargetAccountID",
     66 		func(follow *gtsmodel.Follow) error {
     67 			return r.conn.NewSelect().
     68 				Model(follow).
     69 				Where("? = ?", bun.Ident("account_id"), sourceAccountID).
     70 				Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
     71 				Scan(ctx)
     72 		},
     73 		sourceAccountID,
     74 		targetAccountID,
     75 	)
     76 }
     77 
     78 func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
     79 	// Preallocate slice of expected length.
     80 	follows := make([]*gtsmodel.Follow, 0, len(ids))
     81 
     82 	for _, id := range ids {
     83 		// Fetch follow model for this ID.
     84 		follow, err := r.GetFollowByID(ctx, id)
     85 		if err != nil {
     86 			log.Errorf(ctx, "error getting follow %q: %v", id, err)
     87 			continue
     88 		}
     89 
     90 		// Append to return slice.
     91 		follows = append(follows, follow)
     92 	}
     93 
     94 	return follows, nil
     95 }
     96 
     97 func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
     98 	follow, err := r.GetFollow(
     99 		gtscontext.SetBarebones(ctx),
    100 		sourceAccountID,
    101 		targetAccountID,
    102 	)
    103 	if err != nil && !errors.Is(err, db.ErrNoEntries) {
    104 		return false, err
    105 	}
    106 	return (follow != nil), nil
    107 }
    108 
    109 func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) {
    110 	// make sure account 1 follows account 2
    111 	f1, err := r.IsFollowing(ctx,
    112 		accountID1,
    113 		accountID2,
    114 	)
    115 	if !f1 /* f1 = false when err != nil */ {
    116 		return false, err
    117 	}
    118 
    119 	// make sure account 2 follows account 1
    120 	f2, err := r.IsFollowing(ctx,
    121 		accountID2,
    122 		accountID1,
    123 	)
    124 	if !f2 /* f2 = false when err != nil */ {
    125 		return false, err
    126 	}
    127 
    128 	return true, nil
    129 }
    130 
    131 func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
    132 	// Fetch follow from database cache with loader callback
    133 	follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
    134 		var follow gtsmodel.Follow
    135 
    136 		// Not cached! Perform database query
    137 		if err := dbQuery(&follow); err != nil {
    138 			return nil, r.conn.ProcessError(err)
    139 		}
    140 
    141 		return &follow, nil
    142 	}, keyParts...)
    143 	if err != nil {
    144 		// error already processed
    145 		return nil, err
    146 	}
    147 
    148 	if gtscontext.Barebones(ctx) {
    149 		// Only a barebones model was requested.
    150 		return follow, nil
    151 	}
    152 
    153 	if err := r.state.DB.PopulateFollow(ctx, follow); err != nil {
    154 		return nil, err
    155 	}
    156 
    157 	return follow, nil
    158 }
    159 
    160 func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error {
    161 	var (
    162 		err  error
    163 		errs = make(gtserror.MultiError, 0, 2)
    164 	)
    165 
    166 	if follow.Account == nil {
    167 		// Follow account is not set, fetch from the database.
    168 		follow.Account, err = r.state.DB.GetAccountByID(
    169 			gtscontext.SetBarebones(ctx),
    170 			follow.AccountID,
    171 		)
    172 		if err != nil {
    173 			errs.Append(fmt.Errorf("error populating follow account: %w", err))
    174 		}
    175 	}
    176 
    177 	if follow.TargetAccount == nil {
    178 		// Follow target account is not set, fetch from the database.
    179 		follow.TargetAccount, err = r.state.DB.GetAccountByID(
    180 			gtscontext.SetBarebones(ctx),
    181 			follow.TargetAccountID,
    182 		)
    183 		if err != nil {
    184 			errs.Append(fmt.Errorf("error populating follow target account: %w", err))
    185 		}
    186 	}
    187 
    188 	return errs.Combine()
    189 }
    190 
    191 func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
    192 	return r.state.Caches.GTS.Follow().Store(follow, func() error {
    193 		_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
    194 		return r.conn.ProcessError(err)
    195 	})
    196 }
    197 
    198 func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Follow, columns ...string) error {
    199 	follow.UpdatedAt = time.Now()
    200 	if len(columns) > 0 {
    201 		// If we're updating by column, ensure "updated_at" is included.
    202 		columns = append(columns, "updated_at")
    203 	}
    204 
    205 	return r.state.Caches.GTS.Follow().Store(follow, func() error {
    206 		if _, err := r.conn.NewUpdate().
    207 			Model(follow).
    208 			Where("? = ?", bun.Ident("follow.id"), follow.ID).
    209 			Column(columns...).
    210 			Exec(ctx); err != nil {
    211 			return r.conn.ProcessError(err)
    212 		}
    213 
    214 		return nil
    215 	})
    216 }
    217 
    218 func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {
    219 	// Delete the follow itself using the given ID.
    220 	if _, err := r.conn.NewDelete().
    221 		Table("follows").
    222 		Where("? = ?", bun.Ident("id"), id).
    223 		Exec(ctx); err != nil {
    224 		return r.conn.ProcessError(err)
    225 	}
    226 
    227 	// Delete every list entry that used this followID.
    228 	if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
    229 		return fmt.Errorf("deleteFollow: error deleting list entries: %w", err)
    230 	}
    231 
    232 	return nil
    233 }
    234 
    235 func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
    236 	defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
    237 
    238 	// Load follow into cache before attempting a delete,
    239 	// as we need it cached in order to trigger the invalidate
    240 	// callback. This in turn invalidates others.
    241 	follow, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id)
    242 	if err != nil {
    243 		if errors.Is(err, db.ErrNoEntries) {
    244 			// Already gone.
    245 			return nil
    246 		}
    247 		return err
    248 	}
    249 
    250 	// Finally delete follow from DB.
    251 	return r.deleteFollow(ctx, follow.ID)
    252 }
    253 
    254 func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
    255 	defer r.state.Caches.GTS.Follow().Invalidate("URI", uri)
    256 
    257 	// Load follow into cache before attempting a delete,
    258 	// as we need it cached in order to trigger the invalidate
    259 	// callback. This in turn invalidates others.
    260 	follow, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri)
    261 	if err != nil {
    262 		if errors.Is(err, db.ErrNoEntries) {
    263 			// Already gone.
    264 			return nil
    265 		}
    266 		return err
    267 	}
    268 
    269 	// Finally delete follow from DB.
    270 	return r.deleteFollow(ctx, follow.ID)
    271 }
    272 
    273 func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
    274 	var followIDs []string
    275 
    276 	// Get full list of IDs.
    277 	if _, err := r.conn.
    278 		NewSelect().
    279 		Column("id").
    280 		Table("follows").
    281 		WhereOr("? = ? OR ? = ?",
    282 			bun.Ident("account_id"),
    283 			accountID,
    284 			bun.Ident("target_account_id"),
    285 			accountID,
    286 		).
    287 		Exec(ctx, &followIDs); err != nil {
    288 		return r.conn.ProcessError(err)
    289 	}
    290 
    291 	defer func() {
    292 		// Invalidate all IDs on return.
    293 		for _, id := range followIDs {
    294 			r.state.Caches.GTS.Follow().Invalidate("ID", id)
    295 		}
    296 	}()
    297 
    298 	// Load all follows into cache, this *really* isn't great
    299 	// but it is the only way we can ensure we invalidate all
    300 	// related caches correctly (e.g. visibility).
    301 	for _, id := range followIDs {
    302 		follow, err := r.GetFollowByID(ctx, id)
    303 		if err != nil && !errors.Is(err, db.ErrNoEntries) {
    304 			return err
    305 		}
    306 
    307 		// Delete each follow from DB.
    308 		if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
    309 			return err
    310 		}
    311 	}
    312 
    313 	return nil
    314 }