gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

relationship.go (8797B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"fmt"
     24 
     25 	"github.com/superseriousbusiness/gotosocial/internal/db"
     26 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
     27 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     28 	"github.com/superseriousbusiness/gotosocial/internal/state"
     29 	"github.com/uptrace/bun"
     30 )
     31 
     32 type relationshipDB struct {
     33 	conn  *DBConn
     34 	state *state.State
     35 }
     36 
     37 func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
     38 	var rel gtsmodel.Relationship
     39 	rel.ID = targetAccount
     40 
     41 	// check if the requesting follows the target
     42 	follow, err := r.GetFollow(
     43 		gtscontext.SetBarebones(ctx),
     44 		requestingAccount,
     45 		targetAccount,
     46 	)
     47 	if err != nil && !errors.Is(err, db.ErrNoEntries) {
     48 		return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
     49 	}
     50 
     51 	if follow != nil {
     52 		// follow exists so we can fill these fields out...
     53 		rel.Following = true
     54 		rel.ShowingReblogs = *follow.ShowReblogs
     55 		rel.Notifying = *follow.Notify
     56 	}
     57 
     58 	// check if the target follows the requesting
     59 	rel.FollowedBy, err = r.IsFollowing(ctx,
     60 		targetAccount,
     61 		requestingAccount,
     62 	)
     63 	if err != nil {
     64 		return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
     65 	}
     66 
     67 	// check if requesting has follow requested target
     68 	rel.Requested, err = r.IsFollowRequested(ctx,
     69 		requestingAccount,
     70 		targetAccount,
     71 	)
     72 	if err != nil {
     73 		return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
     74 	}
     75 
     76 	// check if the requesting account is blocking the target account
     77 	rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
     78 	if err != nil {
     79 		return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
     80 	}
     81 
     82 	// check if the requesting account is blocked by the target account
     83 	rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
     84 	if err != nil {
     85 		return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
     86 	}
     87 
     88 	return &rel, nil
     89 }
     90 
     91 func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
     92 	var followIDs []string
     93 	if err := newSelectFollows(r.conn, accountID).
     94 		Scan(ctx, &followIDs); err != nil {
     95 		return nil, r.conn.ProcessError(err)
     96 	}
     97 	return r.GetFollowsByIDs(ctx, followIDs)
     98 }
     99 
    100 func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
    101 	var followIDs []string
    102 	if err := newSelectLocalFollows(r.conn, accountID).
    103 		Scan(ctx, &followIDs); err != nil {
    104 		return nil, r.conn.ProcessError(err)
    105 	}
    106 	return r.GetFollowsByIDs(ctx, followIDs)
    107 }
    108 
    109 func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
    110 	var followIDs []string
    111 	if err := newSelectFollowers(r.conn, accountID).
    112 		Scan(ctx, &followIDs); err != nil {
    113 		return nil, r.conn.ProcessError(err)
    114 	}
    115 	return r.GetFollowsByIDs(ctx, followIDs)
    116 }
    117 
    118 func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
    119 	var followIDs []string
    120 	if err := newSelectLocalFollowers(r.conn, accountID).
    121 		Scan(ctx, &followIDs); err != nil {
    122 		return nil, r.conn.ProcessError(err)
    123 	}
    124 	return r.GetFollowsByIDs(ctx, followIDs)
    125 }
    126 
    127 func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
    128 	n, err := newSelectFollows(r.conn, accountID).Count(ctx)
    129 	return n, r.conn.ProcessError(err)
    130 }
    131 
    132 func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
    133 	n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
    134 	return n, r.conn.ProcessError(err)
    135 }
    136 
    137 func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
    138 	n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
    139 	return n, r.conn.ProcessError(err)
    140 }
    141 
    142 func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
    143 	n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
    144 	return n, r.conn.ProcessError(err)
    145 }
    146 
    147 func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
    148 	var followReqIDs []string
    149 	if err := newSelectFollowRequests(r.conn, accountID).
    150 		Scan(ctx, &followReqIDs); err != nil {
    151 		return nil, r.conn.ProcessError(err)
    152 	}
    153 	return r.GetFollowRequestsByIDs(ctx, followReqIDs)
    154 }
    155 
    156 func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
    157 	var followReqIDs []string
    158 	if err := newSelectFollowRequesting(r.conn, accountID).
    159 		Scan(ctx, &followReqIDs); err != nil {
    160 		return nil, r.conn.ProcessError(err)
    161 	}
    162 	return r.GetFollowRequestsByIDs(ctx, followReqIDs)
    163 }
    164 
    165 func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
    166 	n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
    167 	return n, r.conn.ProcessError(err)
    168 }
    169 
    170 func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
    171 	n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
    172 	return n, r.conn.ProcessError(err)
    173 }
    174 
    175 // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
    176 func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
    177 	return conn.NewSelect().
    178 		TableExpr("?", bun.Ident("follow_requests")).
    179 		ColumnExpr("?", bun.Ident("id")).
    180 		Where("? = ?", bun.Ident("target_account_id"), accountID).
    181 		OrderExpr("? DESC", bun.Ident("updated_at"))
    182 }
    183 
    184 // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
    185 func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
    186 	return conn.NewSelect().
    187 		TableExpr("?", bun.Ident("follow_requests")).
    188 		ColumnExpr("?", bun.Ident("id")).
    189 		Where("? = ?", bun.Ident("target_account_id"), accountID).
    190 		OrderExpr("? DESC", bun.Ident("updated_at"))
    191 }
    192 
    193 // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
    194 func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
    195 	return conn.NewSelect().
    196 		Table("follows").
    197 		Column("id").
    198 		Where("? = ?", bun.Ident("account_id"), accountID).
    199 		OrderExpr("? DESC", bun.Ident("updated_at"))
    200 }
    201 
    202 // newSelectLocalFollows returns a new select query for all rows in the follows table with
    203 // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
    204 func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
    205 	return conn.NewSelect().
    206 		Table("follows").
    207 		Column("id").
    208 		Where("? = ? AND ? IN (?)",
    209 			bun.Ident("account_id"),
    210 			accountID,
    211 			bun.Ident("target_account_id"),
    212 			conn.NewSelect().
    213 				Table("accounts").
    214 				Column("id").
    215 				Where("? IS NULL", bun.Ident("domain")),
    216 		).
    217 		OrderExpr("? DESC", bun.Ident("updated_at"))
    218 }
    219 
    220 // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
    221 func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
    222 	return conn.NewSelect().
    223 		Table("follows").
    224 		Column("id").
    225 		Where("? = ?", bun.Ident("target_account_id"), accountID).
    226 		OrderExpr("? DESC", bun.Ident("updated_at"))
    227 }
    228 
    229 // newSelectLocalFollowers returns a new select query for all rows in the follows table with
    230 // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
    231 func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
    232 	return conn.NewSelect().
    233 		Table("follows").
    234 		Column("id").
    235 		Where("? = ? AND ? IN (?)",
    236 			bun.Ident("target_account_id"),
    237 			accountID,
    238 			bun.Ident("account_id"),
    239 			conn.NewSelect().
    240 				Table("accounts").
    241 				Column("id").
    242 				Where("? IS NULL", bun.Ident("domain")),
    243 		).
    244 		OrderExpr("? DESC", bun.Ident("updated_at"))
    245 }