relationship.go (8797B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package bundb 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 25 "github.com/superseriousbusiness/gotosocial/internal/db" 26 "github.com/superseriousbusiness/gotosocial/internal/gtscontext" 27 "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" 28 "github.com/superseriousbusiness/gotosocial/internal/state" 29 "github.com/uptrace/bun" 30 ) 31 32 type relationshipDB struct { 33 conn *DBConn 34 state *state.State 35 } 36 37 func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { 38 var rel gtsmodel.Relationship 39 rel.ID = targetAccount 40 41 // check if the requesting follows the target 42 follow, err := r.GetFollow( 43 gtscontext.SetBarebones(ctx), 44 requestingAccount, 45 targetAccount, 46 ) 47 if err != nil && !errors.Is(err, db.ErrNoEntries) { 48 return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err) 49 } 50 51 if follow != nil { 52 // follow exists so we can fill these fields out... 53 rel.Following = true 54 rel.ShowingReblogs = *follow.ShowReblogs 55 rel.Notifying = *follow.Notify 56 } 57 58 // check if the target follows the requesting 59 rel.FollowedBy, err = r.IsFollowing(ctx, 60 targetAccount, 61 requestingAccount, 62 ) 63 if err != nil { 64 return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err) 65 } 66 67 // check if requesting has follow requested target 68 rel.Requested, err = r.IsFollowRequested(ctx, 69 requestingAccount, 70 targetAccount, 71 ) 72 if err != nil { 73 return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err) 74 } 75 76 // check if the requesting account is blocking the target account 77 rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount) 78 if err != nil { 79 return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err) 80 } 81 82 // check if the requesting account is blocked by the target account 83 rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount) 84 if err != nil { 85 return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err) 86 } 87 88 return &rel, nil 89 } 90 91 func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { 92 var followIDs []string 93 if err := newSelectFollows(r.conn, accountID). 94 Scan(ctx, &followIDs); err != nil { 95 return nil, r.conn.ProcessError(err) 96 } 97 return r.GetFollowsByIDs(ctx, followIDs) 98 } 99 100 func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { 101 var followIDs []string 102 if err := newSelectLocalFollows(r.conn, accountID). 103 Scan(ctx, &followIDs); err != nil { 104 return nil, r.conn.ProcessError(err) 105 } 106 return r.GetFollowsByIDs(ctx, followIDs) 107 } 108 109 func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { 110 var followIDs []string 111 if err := newSelectFollowers(r.conn, accountID). 112 Scan(ctx, &followIDs); err != nil { 113 return nil, r.conn.ProcessError(err) 114 } 115 return r.GetFollowsByIDs(ctx, followIDs) 116 } 117 118 func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { 119 var followIDs []string 120 if err := newSelectLocalFollowers(r.conn, accountID). 121 Scan(ctx, &followIDs); err != nil { 122 return nil, r.conn.ProcessError(err) 123 } 124 return r.GetFollowsByIDs(ctx, followIDs) 125 } 126 127 func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { 128 n, err := newSelectFollows(r.conn, accountID).Count(ctx) 129 return n, r.conn.ProcessError(err) 130 } 131 132 func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { 133 n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx) 134 return n, r.conn.ProcessError(err) 135 } 136 137 func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { 138 n, err := newSelectFollowers(r.conn, accountID).Count(ctx) 139 return n, r.conn.ProcessError(err) 140 } 141 142 func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { 143 n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx) 144 return n, r.conn.ProcessError(err) 145 } 146 147 func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { 148 var followReqIDs []string 149 if err := newSelectFollowRequests(r.conn, accountID). 150 Scan(ctx, &followReqIDs); err != nil { 151 return nil, r.conn.ProcessError(err) 152 } 153 return r.GetFollowRequestsByIDs(ctx, followReqIDs) 154 } 155 156 func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { 157 var followReqIDs []string 158 if err := newSelectFollowRequesting(r.conn, accountID). 159 Scan(ctx, &followReqIDs); err != nil { 160 return nil, r.conn.ProcessError(err) 161 } 162 return r.GetFollowRequestsByIDs(ctx, followReqIDs) 163 } 164 165 func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { 166 n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx) 167 return n, r.conn.ProcessError(err) 168 } 169 170 func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { 171 n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx) 172 return n, r.conn.ProcessError(err) 173 } 174 175 // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. 176 func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery { 177 return conn.NewSelect(). 178 TableExpr("?", bun.Ident("follow_requests")). 179 ColumnExpr("?", bun.Ident("id")). 180 Where("? = ?", bun.Ident("target_account_id"), accountID). 181 OrderExpr("? DESC", bun.Ident("updated_at")) 182 } 183 184 // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. 185 func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery { 186 return conn.NewSelect(). 187 TableExpr("?", bun.Ident("follow_requests")). 188 ColumnExpr("?", bun.Ident("id")). 189 Where("? = ?", bun.Ident("target_account_id"), accountID). 190 OrderExpr("? DESC", bun.Ident("updated_at")) 191 } 192 193 // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. 194 func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery { 195 return conn.NewSelect(). 196 Table("follows"). 197 Column("id"). 198 Where("? = ?", bun.Ident("account_id"), accountID). 199 OrderExpr("? DESC", bun.Ident("updated_at")) 200 } 201 202 // newSelectLocalFollows returns a new select query for all rows in the follows table with 203 // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). 204 func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery { 205 return conn.NewSelect(). 206 Table("follows"). 207 Column("id"). 208 Where("? = ? AND ? IN (?)", 209 bun.Ident("account_id"), 210 accountID, 211 bun.Ident("target_account_id"), 212 conn.NewSelect(). 213 Table("accounts"). 214 Column("id"). 215 Where("? IS NULL", bun.Ident("domain")), 216 ). 217 OrderExpr("? DESC", bun.Ident("updated_at")) 218 } 219 220 // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. 221 func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery { 222 return conn.NewSelect(). 223 Table("follows"). 224 Column("id"). 225 Where("? = ?", bun.Ident("target_account_id"), accountID). 226 OrderExpr("? DESC", bun.Ident("updated_at")) 227 } 228 229 // newSelectLocalFollowers returns a new select query for all rows in the follows table with 230 // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). 231 func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery { 232 return conn.NewSelect(). 233 Table("follows"). 234 Column("id"). 235 Where("? = ? AND ? IN (?)", 236 bun.Ident("target_account_id"), 237 accountID, 238 bun.Ident("account_id"), 239 conn.NewSelect(). 240 Table("accounts"). 241 Column("id"). 242 Where("? IS NULL", bun.Ident("domain")), 243 ). 244 OrderExpr("? DESC", bun.Ident("updated_at")) 245 }