gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

relationship_follow_req.go (10618B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"fmt"
     24 	"time"
     25 
     26 	"github.com/superseriousbusiness/gotosocial/internal/db"
     27 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
     28 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     29 	"github.com/superseriousbusiness/gotosocial/internal/log"
     30 	"github.com/uptrace/bun"
     31 )
     32 
     33 func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) {
     34 	return r.getFollowRequest(
     35 		ctx,
     36 		"ID",
     37 		func(followReq *gtsmodel.FollowRequest) error {
     38 			return r.conn.NewSelect().
     39 				Model(followReq).
     40 				Where("? = ?", bun.Ident("id"), id).
     41 				Scan(ctx)
     42 		},
     43 		id,
     44 	)
     45 }
     46 
     47 func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) {
     48 	return r.getFollowRequest(
     49 		ctx,
     50 		"URI",
     51 		func(followReq *gtsmodel.FollowRequest) error {
     52 			return r.conn.NewSelect().
     53 				Model(followReq).
     54 				Where("? = ?", bun.Ident("uri"), uri).
     55 				Scan(ctx)
     56 		},
     57 		uri,
     58 	)
     59 }
     60 
     61 func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
     62 	return r.getFollowRequest(
     63 		ctx,
     64 		"AccountID.TargetAccountID",
     65 		func(followReq *gtsmodel.FollowRequest) error {
     66 			return r.conn.NewSelect().
     67 				Model(followReq).
     68 				Where("? = ?", bun.Ident("account_id"), sourceAccountID).
     69 				Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
     70 				Scan(ctx)
     71 		},
     72 		sourceAccountID,
     73 		targetAccountID,
     74 	)
     75 }
     76 
     77 func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
     78 	// Preallocate slice of expected length.
     79 	followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
     80 
     81 	for _, id := range ids {
     82 		// Fetch follow request model for this ID.
     83 		followReq, err := r.GetFollowRequestByID(ctx, id)
     84 		if err != nil {
     85 			log.Errorf(ctx, "error getting follow request %q: %v", id, err)
     86 			continue
     87 		}
     88 
     89 		// Append to return slice.
     90 		followReqs = append(followReqs, followReq)
     91 	}
     92 
     93 	return followReqs, nil
     94 }
     95 
     96 func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
     97 	followReq, err := r.GetFollowRequest(
     98 		gtscontext.SetBarebones(ctx),
     99 		sourceAccountID,
    100 		targetAccountID,
    101 	)
    102 	if err != nil && !errors.Is(err, db.ErrNoEntries) {
    103 		return false, err
    104 	}
    105 	return (followReq != nil), nil
    106 }
    107 
    108 func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
    109 	// Fetch follow request from database cache with loader callback
    110 	followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
    111 		var followReq gtsmodel.FollowRequest
    112 
    113 		// Not cached! Perform database query
    114 		if err := dbQuery(&followReq); err != nil {
    115 			return nil, r.conn.ProcessError(err)
    116 		}
    117 
    118 		return &followReq, nil
    119 	}, keyParts...)
    120 	if err != nil {
    121 		// error already processed
    122 		return nil, err
    123 	}
    124 
    125 	if gtscontext.Barebones(ctx) {
    126 		// Only a barebones model was requested.
    127 		return followReq, nil
    128 	}
    129 
    130 	// Set the follow request source account
    131 	followReq.Account, err = r.state.DB.GetAccountByID(
    132 		gtscontext.SetBarebones(ctx),
    133 		followReq.AccountID,
    134 	)
    135 	if err != nil {
    136 		return nil, fmt.Errorf("error getting follow request source account: %w", err)
    137 	}
    138 
    139 	// Set the follow request target account
    140 	followReq.TargetAccount, err = r.state.DB.GetAccountByID(
    141 		gtscontext.SetBarebones(ctx),
    142 		followReq.TargetAccountID,
    143 	)
    144 	if err != nil {
    145 		return nil, fmt.Errorf("error getting follow request target account: %w", err)
    146 	}
    147 
    148 	return followReq, nil
    149 }
    150 
    151 func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
    152 	return r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
    153 		_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
    154 		return r.conn.ProcessError(err)
    155 	})
    156 }
    157 
    158 func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest, columns ...string) error {
    159 	followRequest.UpdatedAt = time.Now()
    160 	if len(columns) > 0 {
    161 		// If we're updating by column, ensure "updated_at" is included.
    162 		columns = append(columns, "updated_at")
    163 	}
    164 
    165 	return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error {
    166 		if _, err := r.conn.NewUpdate().
    167 			Model(followRequest).
    168 			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
    169 			Column(columns...).
    170 			Exec(ctx); err != nil {
    171 			return r.conn.ProcessError(err)
    172 		}
    173 
    174 		return nil
    175 	})
    176 }
    177 
    178 func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
    179 	// Get original follow request.
    180 	followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
    181 	if err != nil {
    182 		return nil, err
    183 	}
    184 
    185 	// Create a new follow to 'replace'
    186 	// the original follow request with.
    187 	follow := &gtsmodel.Follow{
    188 		ID:              followReq.ID,
    189 		AccountID:       sourceAccountID,
    190 		Account:         followReq.Account,
    191 		TargetAccountID: targetAccountID,
    192 		TargetAccount:   followReq.TargetAccount,
    193 		URI:             followReq.URI,
    194 		ShowReblogs:     followReq.ShowReblogs,
    195 		Notify:          followReq.Notify,
    196 	}
    197 
    198 	if err := r.state.Caches.GTS.Follow().Store(follow, func() error {
    199 		// If the follow already exists, just
    200 		// replace the URI with the new one.
    201 		_, err := r.conn.
    202 			NewInsert().
    203 			Model(follow).
    204 			On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
    205 			Exec(ctx)
    206 		return r.conn.ProcessError(err)
    207 	}); err != nil {
    208 		return nil, err
    209 	}
    210 
    211 	// Invalidate follow request from cache lookups on return.
    212 	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
    213 
    214 	// Delete original follow request.
    215 	if _, err := r.conn.
    216 		NewDelete().
    217 		Table("follow_requests").
    218 		Where("? = ?", bun.Ident("id"), followReq.ID).
    219 		Exec(ctx); err != nil {
    220 		return nil, r.conn.ProcessError(err)
    221 	}
    222 
    223 	// Delete original follow request notification
    224 	if err := r.state.DB.DeleteNotifications(ctx, []string{
    225 		string(gtsmodel.NotificationFollowRequest),
    226 	}, targetAccountID, sourceAccountID); err != nil {
    227 		return nil, err
    228 	}
    229 
    230 	return follow, nil
    231 }
    232 
    233 func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error {
    234 	defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
    235 
    236 	// Load followreq into cache before attempting a delete,
    237 	// as we need it cached in order to trigger the invalidate
    238 	// callback. This in turn invalidates others.
    239 	_, err := r.GetFollowRequest(gtscontext.SetBarebones(ctx),
    240 		sourceAccountID,
    241 		targetAccountID,
    242 	)
    243 	if err != nil {
    244 		return err
    245 	}
    246 
    247 	// Attempt to delete follow request.
    248 	if _, err = r.conn.NewDelete().
    249 		Table("follow_requests").
    250 		Where("? = ? AND ? = ?",
    251 			bun.Ident("account_id"),
    252 			sourceAccountID,
    253 			bun.Ident("target_account_id"),
    254 			targetAccountID,
    255 		).
    256 		Exec(ctx); err != nil {
    257 		return r.conn.ProcessError(err)
    258 	}
    259 
    260 	// Delete original follow request notification
    261 	return r.state.DB.DeleteNotifications(ctx, []string{
    262 		string(gtsmodel.NotificationFollowRequest),
    263 	}, targetAccountID, sourceAccountID)
    264 }
    265 
    266 func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
    267 	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
    268 
    269 	// Load followreq into cache before attempting a delete,
    270 	// as we need it cached in order to trigger the invalidate
    271 	// callback. This in turn invalidates others.
    272 	_, err := r.GetFollowRequestByID(gtscontext.SetBarebones(ctx), id)
    273 	if err != nil {
    274 		if errors.Is(err, db.ErrNoEntries) {
    275 			// not an issue.
    276 			err = nil
    277 		}
    278 		return err
    279 	}
    280 
    281 	// Finally delete followreq from DB.
    282 	_, err = r.conn.NewDelete().
    283 		Table("follow_requests").
    284 		Where("? = ?", bun.Ident("id"), id).
    285 		Exec(ctx)
    286 	return r.conn.ProcessError(err)
    287 }
    288 
    289 func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
    290 	defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
    291 
    292 	// Load followreq into cache before attempting a delete,
    293 	// as we need it cached in order to trigger the invalidate
    294 	// callback. This in turn invalidates others.
    295 	_, err := r.GetFollowRequestByURI(gtscontext.SetBarebones(ctx), uri)
    296 	if err != nil {
    297 		if errors.Is(err, db.ErrNoEntries) {
    298 			// not an issue.
    299 			err = nil
    300 		}
    301 		return err
    302 	}
    303 
    304 	// Finally delete followreq from DB.
    305 	_, err = r.conn.NewDelete().
    306 		Table("follow_requests").
    307 		Where("? = ?", bun.Ident("uri"), uri).
    308 		Exec(ctx)
    309 	return r.conn.ProcessError(err)
    310 }
    311 
    312 func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
    313 	var followReqIDs []string
    314 
    315 	// Get full list of IDs.
    316 	if _, err := r.conn.
    317 		NewSelect().
    318 		Column("id").
    319 		Table("follow_requestss").
    320 		WhereOr("? = ? OR ? = ?",
    321 			bun.Ident("account_id"),
    322 			accountID,
    323 			bun.Ident("target_account_id"),
    324 			accountID,
    325 		).
    326 		Exec(ctx, &followReqIDs); err != nil {
    327 		return r.conn.ProcessError(err)
    328 	}
    329 
    330 	defer func() {
    331 		// Invalidate all IDs on return.
    332 		for _, id := range followReqIDs {
    333 			r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
    334 		}
    335 	}()
    336 
    337 	// Load all followreqs into cache, this *really* isn't
    338 	// great but it is the only way we can ensure we invalidate
    339 	// all related caches correctly (e.g. visibility).
    340 	for _, id := range followReqIDs {
    341 		_, err := r.GetFollowRequestByID(ctx, id)
    342 		if err != nil && !errors.Is(err, db.ErrNoEntries) {
    343 			return err
    344 		}
    345 	}
    346 
    347 	// Finally delete all from DB.
    348 	_, err := r.conn.NewDelete().
    349 		Table("follow_requests").
    350 		Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)).
    351 		Exec(ctx)
    352 	return r.conn.ProcessError(err)
    353 }