relationship_block.go (6604B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package bundb 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 25 "github.com/superseriousbusiness/gotosocial/internal/db" 26 "github.com/superseriousbusiness/gotosocial/internal/gtscontext" 27 "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" 28 "github.com/uptrace/bun" 29 ) 30 31 func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { 32 block, err := r.GetBlock( 33 gtscontext.SetBarebones(ctx), 34 sourceAccountID, 35 targetAccountID, 36 ) 37 if err != nil && !errors.Is(err, db.ErrNoEntries) { 38 return false, err 39 } 40 return (block != nil), nil 41 } 42 43 func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) { 44 // Look for a block in direction of account1->account2 45 b1, err := r.IsBlocked(ctx, accountID1, accountID2) 46 if err != nil || b1 { 47 return true, err 48 } 49 50 // Look for a block in direction of account2->account1 51 b2, err := r.IsBlocked(ctx, accountID2, accountID1) 52 if err != nil || b2 { 53 return true, err 54 } 55 56 return false, nil 57 } 58 59 func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) { 60 return r.getBlock( 61 ctx, 62 "ID", 63 func(block *gtsmodel.Block) error { 64 return r.conn.NewSelect().Model(block). 65 Where("? = ?", bun.Ident("block.id"), id). 66 Scan(ctx) 67 }, 68 id, 69 ) 70 } 71 72 func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) { 73 return r.getBlock( 74 ctx, 75 "URI", 76 func(block *gtsmodel.Block) error { 77 return r.conn.NewSelect().Model(block). 78 Where("? = ?", bun.Ident("block.uri"), uri). 79 Scan(ctx) 80 }, 81 uri, 82 ) 83 } 84 85 func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) { 86 return r.getBlock( 87 ctx, 88 "AccountID.TargetAccountID", 89 func(block *gtsmodel.Block) error { 90 return r.conn.NewSelect().Model(block). 91 Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). 92 Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID). 93 Scan(ctx) 94 }, 95 sourceAccountID, 96 targetAccountID, 97 ) 98 } 99 100 func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) { 101 // Fetch block from cache with loader callback 102 block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { 103 var block gtsmodel.Block 104 105 // Not cached! Perform database query 106 if err := dbQuery(&block); err != nil { 107 return nil, r.conn.ProcessError(err) 108 } 109 110 return &block, nil 111 }, keyParts...) 112 if err != nil { 113 // already processe 114 return nil, err 115 } 116 117 if gtscontext.Barebones(ctx) { 118 // Only a barebones model was requested. 119 return block, nil 120 } 121 122 // Set the block source account 123 block.Account, err = r.state.DB.GetAccountByID( 124 gtscontext.SetBarebones(ctx), 125 block.AccountID, 126 ) 127 if err != nil { 128 return nil, fmt.Errorf("error getting block source account: %w", err) 129 } 130 131 // Set the block target account 132 block.TargetAccount, err = r.state.DB.GetAccountByID( 133 gtscontext.SetBarebones(ctx), 134 block.TargetAccountID, 135 ) 136 if err != nil { 137 return nil, fmt.Errorf("error getting block target account: %w", err) 138 } 139 140 return block, nil 141 } 142 143 func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { 144 return r.state.Caches.GTS.Block().Store(block, func() error { 145 _, err := r.conn.NewInsert().Model(block).Exec(ctx) 146 return r.conn.ProcessError(err) 147 }) 148 } 149 150 func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { 151 defer r.state.Caches.GTS.Block().Invalidate("ID", id) 152 153 // Load block into cache before attempting a delete, 154 // as we need it cached in order to trigger the invalidate 155 // callback. This in turn invalidates others. 156 _, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id) 157 if err != nil { 158 if errors.Is(err, db.ErrNoEntries) { 159 // not an issue. 160 err = nil 161 } 162 return err 163 } 164 165 // Finally delete block from DB. 166 _, err = r.conn.NewDelete(). 167 Table("blocks"). 168 Where("? = ?", bun.Ident("id"), id). 169 Exec(ctx) 170 return r.conn.ProcessError(err) 171 } 172 173 func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { 174 defer r.state.Caches.GTS.Block().Invalidate("URI", uri) 175 176 // Load block into cache before attempting a delete, 177 // as we need it cached in order to trigger the invalidate 178 // callback. This in turn invalidates others. 179 _, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri) 180 if err != nil { 181 if errors.Is(err, db.ErrNoEntries) { 182 // not an issue. 183 err = nil 184 } 185 return err 186 } 187 188 // Finally delete block from DB. 189 _, err = r.conn.NewDelete(). 190 Table("blocks"). 191 Where("? = ?", bun.Ident("uri"), uri). 192 Exec(ctx) 193 return r.conn.ProcessError(err) 194 } 195 196 func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error { 197 var blockIDs []string 198 199 // Get full list of IDs. 200 if err := r.conn.NewSelect(). 201 Column("id"). 202 Table("blocks"). 203 WhereOr("? = ? OR ? = ?", 204 bun.Ident("account_id"), 205 accountID, 206 bun.Ident("target_account_id"), 207 accountID, 208 ). 209 Scan(ctx, &blockIDs); err != nil { 210 return r.conn.ProcessError(err) 211 } 212 213 defer func() { 214 // Invalidate all IDs on return. 215 for _, id := range blockIDs { 216 r.state.Caches.GTS.Block().Invalidate("ID", id) 217 } 218 }() 219 220 // Load all blocks into cache, this *really* isn't great 221 // but it is the only way we can ensure we invalidate all 222 // related caches correctly (e.g. visibility). 223 for _, id := range blockIDs { 224 _, err := r.GetBlockByID(ctx, id) 225 if err != nil && !errors.Is(err, db.ErrNoEntries) { 226 return err 227 } 228 } 229 230 // Finally delete all from DB. 231 _, err := r.conn.NewDelete(). 232 Table("blocks"). 233 Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)). 234 Exec(ctx) 235 return r.conn.ProcessError(err) 236 }