gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

relationship_block.go (6604B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"fmt"
     24 
     25 	"github.com/superseriousbusiness/gotosocial/internal/db"
     26 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
     27 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     28 	"github.com/uptrace/bun"
     29 )
     30 
     31 func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
     32 	block, err := r.GetBlock(
     33 		gtscontext.SetBarebones(ctx),
     34 		sourceAccountID,
     35 		targetAccountID,
     36 	)
     37 	if err != nil && !errors.Is(err, db.ErrNoEntries) {
     38 		return false, err
     39 	}
     40 	return (block != nil), nil
     41 }
     42 
     43 func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
     44 	// Look for a block in direction of account1->account2
     45 	b1, err := r.IsBlocked(ctx, accountID1, accountID2)
     46 	if err != nil || b1 {
     47 		return true, err
     48 	}
     49 
     50 	// Look for a block in direction of account2->account1
     51 	b2, err := r.IsBlocked(ctx, accountID2, accountID1)
     52 	if err != nil || b2 {
     53 		return true, err
     54 	}
     55 
     56 	return false, nil
     57 }
     58 
     59 func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) {
     60 	return r.getBlock(
     61 		ctx,
     62 		"ID",
     63 		func(block *gtsmodel.Block) error {
     64 			return r.conn.NewSelect().Model(block).
     65 				Where("? = ?", bun.Ident("block.id"), id).
     66 				Scan(ctx)
     67 		},
     68 		id,
     69 	)
     70 }
     71 
     72 func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) {
     73 	return r.getBlock(
     74 		ctx,
     75 		"URI",
     76 		func(block *gtsmodel.Block) error {
     77 			return r.conn.NewSelect().Model(block).
     78 				Where("? = ?", bun.Ident("block.uri"), uri).
     79 				Scan(ctx)
     80 		},
     81 		uri,
     82 	)
     83 }
     84 
     85 func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
     86 	return r.getBlock(
     87 		ctx,
     88 		"AccountID.TargetAccountID",
     89 		func(block *gtsmodel.Block) error {
     90 			return r.conn.NewSelect().Model(block).
     91 				Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
     92 				Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
     93 				Scan(ctx)
     94 		},
     95 		sourceAccountID,
     96 		targetAccountID,
     97 	)
     98 }
     99 
    100 func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
    101 	// Fetch block from cache with loader callback
    102 	block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
    103 		var block gtsmodel.Block
    104 
    105 		// Not cached! Perform database query
    106 		if err := dbQuery(&block); err != nil {
    107 			return nil, r.conn.ProcessError(err)
    108 		}
    109 
    110 		return &block, nil
    111 	}, keyParts...)
    112 	if err != nil {
    113 		// already processe
    114 		return nil, err
    115 	}
    116 
    117 	if gtscontext.Barebones(ctx) {
    118 		// Only a barebones model was requested.
    119 		return block, nil
    120 	}
    121 
    122 	// Set the block source account
    123 	block.Account, err = r.state.DB.GetAccountByID(
    124 		gtscontext.SetBarebones(ctx),
    125 		block.AccountID,
    126 	)
    127 	if err != nil {
    128 		return nil, fmt.Errorf("error getting block source account: %w", err)
    129 	}
    130 
    131 	// Set the block target account
    132 	block.TargetAccount, err = r.state.DB.GetAccountByID(
    133 		gtscontext.SetBarebones(ctx),
    134 		block.TargetAccountID,
    135 	)
    136 	if err != nil {
    137 		return nil, fmt.Errorf("error getting block target account: %w", err)
    138 	}
    139 
    140 	return block, nil
    141 }
    142 
    143 func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
    144 	return r.state.Caches.GTS.Block().Store(block, func() error {
    145 		_, err := r.conn.NewInsert().Model(block).Exec(ctx)
    146 		return r.conn.ProcessError(err)
    147 	})
    148 }
    149 
    150 func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
    151 	defer r.state.Caches.GTS.Block().Invalidate("ID", id)
    152 
    153 	// Load block into cache before attempting a delete,
    154 	// as we need it cached in order to trigger the invalidate
    155 	// callback. This in turn invalidates others.
    156 	_, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id)
    157 	if err != nil {
    158 		if errors.Is(err, db.ErrNoEntries) {
    159 			// not an issue.
    160 			err = nil
    161 		}
    162 		return err
    163 	}
    164 
    165 	// Finally delete block from DB.
    166 	_, err = r.conn.NewDelete().
    167 		Table("blocks").
    168 		Where("? = ?", bun.Ident("id"), id).
    169 		Exec(ctx)
    170 	return r.conn.ProcessError(err)
    171 }
    172 
    173 func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
    174 	defer r.state.Caches.GTS.Block().Invalidate("URI", uri)
    175 
    176 	// Load block into cache before attempting a delete,
    177 	// as we need it cached in order to trigger the invalidate
    178 	// callback. This in turn invalidates others.
    179 	_, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri)
    180 	if err != nil {
    181 		if errors.Is(err, db.ErrNoEntries) {
    182 			// not an issue.
    183 			err = nil
    184 		}
    185 		return err
    186 	}
    187 
    188 	// Finally delete block from DB.
    189 	_, err = r.conn.NewDelete().
    190 		Table("blocks").
    191 		Where("? = ?", bun.Ident("uri"), uri).
    192 		Exec(ctx)
    193 	return r.conn.ProcessError(err)
    194 }
    195 
    196 func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
    197 	var blockIDs []string
    198 
    199 	// Get full list of IDs.
    200 	if err := r.conn.NewSelect().
    201 		Column("id").
    202 		Table("blocks").
    203 		WhereOr("? = ? OR ? = ?",
    204 			bun.Ident("account_id"),
    205 			accountID,
    206 			bun.Ident("target_account_id"),
    207 			accountID,
    208 		).
    209 		Scan(ctx, &blockIDs); err != nil {
    210 		return r.conn.ProcessError(err)
    211 	}
    212 
    213 	defer func() {
    214 		// Invalidate all IDs on return.
    215 		for _, id := range blockIDs {
    216 			r.state.Caches.GTS.Block().Invalidate("ID", id)
    217 		}
    218 	}()
    219 
    220 	// Load all blocks into cache, this *really* isn't great
    221 	// but it is the only way we can ensure we invalidate all
    222 	// related caches correctly (e.g. visibility).
    223 	for _, id := range blockIDs {
    224 		_, err := r.GetBlockByID(ctx, id)
    225 		if err != nil && !errors.Is(err, db.ErrNoEntries) {
    226 			return err
    227 		}
    228 	}
    229 
    230 	// Finally delete all from DB.
    231 	_, err := r.conn.NewDelete().
    232 		Table("blocks").
    233 		Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).
    234 		Exec(ctx)
    235 	return r.conn.ProcessError(err)
    236 }