notification.go (8318B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package bundb 19 20 import ( 21 "context" 22 "errors" 23 24 "github.com/superseriousbusiness/gotosocial/internal/db" 25 "github.com/superseriousbusiness/gotosocial/internal/gtscontext" 26 "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" 27 "github.com/superseriousbusiness/gotosocial/internal/id" 28 "github.com/superseriousbusiness/gotosocial/internal/log" 29 "github.com/superseriousbusiness/gotosocial/internal/state" 30 "github.com/uptrace/bun" 31 ) 32 33 type notificationDB struct { 34 conn *DBConn 35 state *state.State 36 } 37 38 func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { 39 return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { 40 var notif gtsmodel.Notification 41 42 q := n.conn.NewSelect(). 43 Model(¬if). 44 Where("? = ?", bun.Ident("notification.id"), id) 45 if err := q.Scan(ctx); err != nil { 46 return nil, n.conn.ProcessError(err) 47 } 48 49 return ¬if, nil 50 }, id) 51 } 52 53 func (n *notificationDB) GetNotification( 54 ctx context.Context, 55 notificationType gtsmodel.NotificationType, 56 targetAccountID string, 57 originAccountID string, 58 statusID string, 59 ) (*gtsmodel.Notification, db.Error) { 60 return n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { 61 var notif gtsmodel.Notification 62 63 q := n.conn.NewSelect(). 64 Model(¬if). 65 Where("? = ?", bun.Ident("notification_type"), notificationType). 66 Where("? = ?", bun.Ident("target_account_id"), targetAccountID). 67 Where("? = ?", bun.Ident("origin_account_id"), originAccountID). 68 Where("? = ?", bun.Ident("status_id"), statusID) 69 70 if err := q.Scan(ctx); err != nil { 71 return nil, n.conn.ProcessError(err) 72 } 73 74 return ¬if, nil 75 }, notificationType, targetAccountID, originAccountID, statusID) 76 } 77 78 func (n *notificationDB) GetAccountNotifications( 79 ctx context.Context, 80 accountID string, 81 maxID string, 82 sinceID string, 83 minID string, 84 limit int, 85 excludeTypes []string, 86 ) ([]*gtsmodel.Notification, db.Error) { 87 // Ensure reasonable 88 if limit < 0 { 89 limit = 0 90 } 91 92 // Make educated guess for slice size 93 var ( 94 notifIDs = make([]string, 0, limit) 95 frontToBack = true 96 ) 97 98 q := n.conn. 99 NewSelect(). 100 TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). 101 Column("notification.id") 102 103 if maxID == "" { 104 maxID = id.Highest 105 } 106 107 // Return only notifs LOWER (ie., older) than maxID. 108 q = q.Where("? < ?", bun.Ident("notification.id"), maxID) 109 110 if sinceID != "" { 111 // Return only notifs HIGHER (ie., newer) than sinceID. 112 q = q.Where("? > ?", bun.Ident("notification.id"), sinceID) 113 } 114 115 if minID != "" { 116 // Return only notifs HIGHER (ie., newer) than minID. 117 q = q.Where("? > ?", bun.Ident("notification.id"), minID) 118 119 frontToBack = false // page up 120 } 121 122 for _, excludeType := range excludeTypes { 123 // Filter out unwanted notif types. 124 q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType) 125 } 126 127 // Return only notifs for this account. 128 q = q.Where("? = ?", bun.Ident("notification.target_account_id"), accountID) 129 130 if limit > 0 { 131 q = q.Limit(limit) 132 } 133 134 if frontToBack { 135 // Page down. 136 q = q.Order("notification.id DESC") 137 } else { 138 // Page up. 139 q = q.Order("notification.id ASC") 140 } 141 142 if err := q.Scan(ctx, ¬ifIDs); err != nil { 143 return nil, n.conn.ProcessError(err) 144 } 145 146 if len(notifIDs) == 0 { 147 return nil, nil 148 } 149 150 // If we're paging up, we still want notifications 151 // to be sorted by ID desc, so reverse ids slice. 152 // https://zchee.github.io/golang-wiki/SliceTricks/#reversing 153 if !frontToBack { 154 for l, r := 0, len(notifIDs)-1; l < r; l, r = l+1, r-1 { 155 notifIDs[l], notifIDs[r] = notifIDs[r], notifIDs[l] 156 } 157 } 158 159 notifs := make([]*gtsmodel.Notification, 0, len(notifIDs)) 160 for _, id := range notifIDs { 161 // Attempt fetch from DB 162 notif, err := n.GetNotificationByID(ctx, id) 163 if err != nil { 164 log.Errorf(ctx, "error fetching notification %q: %v", id, err) 165 continue 166 } 167 168 // Append notification 169 notifs = append(notifs, notif) 170 } 171 172 return notifs, nil 173 } 174 175 func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { 176 return n.state.Caches.GTS.Notification().Store(notif, func() error { 177 _, err := n.conn.NewInsert().Model(notif).Exec(ctx) 178 return n.conn.ProcessError(err) 179 }) 180 } 181 182 func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error { 183 defer n.state.Caches.GTS.Notification().Invalidate("ID", id) 184 185 // Load notif into cache before attempting a delete, 186 // as we need it cached in order to trigger the invalidate 187 // callback. This in turn invalidates others. 188 _, err := n.GetNotificationByID(gtscontext.SetBarebones(ctx), id) 189 if err != nil { 190 if errors.Is(err, db.ErrNoEntries) { 191 // not an issue. 192 err = nil 193 } 194 return err 195 } 196 197 // Finally delete notif from DB. 198 _, err = n.conn.NewDelete(). 199 TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). 200 Where("? = ?", bun.Ident("notification.id"), id). 201 Exec(ctx) 202 return n.conn.ProcessError(err) 203 } 204 205 func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error { 206 if targetAccountID == "" && originAccountID == "" { 207 return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set") 208 } 209 210 var notifIDs []string 211 212 q := n.conn. 213 NewSelect(). 214 Column("id"). 215 Table("notifications") 216 217 if len(types) > 0 { 218 q = q.Where("? IN (?)", bun.Ident("notification_type"), bun.In(types)) 219 } 220 221 if targetAccountID != "" { 222 q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID) 223 } 224 225 if originAccountID != "" { 226 q = q.Where("? = ?", bun.Ident("origin_account_id"), originAccountID) 227 } 228 229 if _, err := q.Exec(ctx, ¬ifIDs); err != nil { 230 return n.conn.ProcessError(err) 231 } 232 233 defer func() { 234 // Invalidate all IDs on return. 235 for _, id := range notifIDs { 236 n.state.Caches.GTS.Notification().Invalidate("ID", id) 237 } 238 }() 239 240 // Load all notif into cache, this *really* isn't great 241 // but it is the only way we can ensure we invalidate all 242 // related caches correctly (e.g. visibility). 243 for _, id := range notifIDs { 244 _, err := n.GetNotificationByID(ctx, id) 245 if err != nil && !errors.Is(err, db.ErrNoEntries) { 246 return err 247 } 248 } 249 250 // Finally delete all from DB. 251 _, err := n.conn.NewDelete(). 252 Table("notifications"). 253 Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). 254 Exec(ctx) 255 return n.conn.ProcessError(err) 256 } 257 258 func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error { 259 var notifIDs []string 260 261 q := n.conn. 262 NewSelect(). 263 Column("id"). 264 Table("notifications"). 265 Where("? = ?", bun.Ident("status_id"), statusID) 266 267 if _, err := q.Exec(ctx, ¬ifIDs); err != nil { 268 return n.conn.ProcessError(err) 269 } 270 271 defer func() { 272 // Invalidate all IDs on return. 273 for _, id := range notifIDs { 274 n.state.Caches.GTS.Notification().Invalidate("ID", id) 275 } 276 }() 277 278 // Load all notif into cache, this *really* isn't great 279 // but it is the only way we can ensure we invalidate all 280 // related caches correctly (e.g. visibility). 281 for _, id := range notifIDs { 282 _, err := n.GetNotificationByID(ctx, id) 283 if err != nil && !errors.Is(err, db.ErrNoEntries) { 284 return err 285 } 286 } 287 288 // Finally delete all from DB. 289 _, err := n.conn.NewDelete(). 290 Table("notifications"). 291 Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). 292 Exec(ctx) 293 return n.conn.ProcessError(err) 294 }