gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

emoji.go (13249B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package bundb
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"strings"
     24 	"time"
     25 
     26 	"github.com/superseriousbusiness/gotosocial/internal/db"
     27 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
     28 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     29 	"github.com/superseriousbusiness/gotosocial/internal/log"
     30 	"github.com/superseriousbusiness/gotosocial/internal/state"
     31 	"github.com/uptrace/bun"
     32 	"github.com/uptrace/bun/dialect"
     33 )
     34 
// emojiDB provides the emoji and emoji-category database functions in this
// file, issuing queries through conn and reading/writing entries through the
// emoji and emoji category caches reachable from state.
type emojiDB struct {
	conn  *DBConn      // database connection used for every query in this file
	state *state.State // source of the GTS emoji / emoji category caches
}
     39 
     40 func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery {
     41 	return e.conn.
     42 		NewSelect().
     43 		Model(emoji).
     44 		Relation("Category")
     45 }
     46 
     47 func (e *emojiDB) newEmojiCategoryQ(emojiCategory *gtsmodel.EmojiCategory) *bun.SelectQuery {
     48 	return e.conn.
     49 		NewSelect().
     50 		Model(emojiCategory)
     51 }
     52 
     53 func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error {
     54 	return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
     55 		_, err := e.conn.NewInsert().Model(emoji).Exec(ctx)
     56 		return e.conn.ProcessError(err)
     57 	})
     58 }
     59 
     60 func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) (*gtsmodel.Emoji, db.Error) {
     61 	emoji.UpdatedAt = time.Now()
     62 	if len(columns) > 0 {
     63 		// If we're updating by column, ensure "updated_at" is included.
     64 		columns = append(columns, "updated_at")
     65 	}
     66 
     67 	err := e.state.Caches.GTS.Emoji().Store(emoji, func() error {
     68 		_, err := e.conn.
     69 			NewUpdate().
     70 			Model(emoji).
     71 			Where("? = ?", bun.Ident("emoji.id"), emoji.ID).
     72 			Column(columns...).
     73 			Exec(ctx)
     74 		return e.conn.ProcessError(err)
     75 	})
     76 	if err != nil {
     77 		return nil, err
     78 	}
     79 
     80 	return emoji, nil
     81 }
     82 
     83 func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
     84 	defer e.state.Caches.GTS.Emoji().Invalidate("ID", id)
     85 
     86 	// Load emoji into cache before attempting a delete,
     87 	// as we need it cached in order to trigger the invalidate
     88 	// callback. This in turn invalidates others.
     89 	_, err := e.GetEmojiByID(
     90 		gtscontext.SetBarebones(ctx),
     91 		id,
     92 	)
     93 	if err != nil && !errors.Is(err, db.ErrNoEntries) {
     94 		// NOTE: even if db.ErrNoEntries is returned, we
     95 		// still run the below transaction to ensure related
     96 		// objects are appropriately deleted.
     97 		return err
     98 	}
     99 
    100 	return e.conn.RunInTx(ctx, func(tx bun.Tx) error {
    101 		// delete links between this emoji and any statuses that use it
    102 		if _, err := tx.
    103 			NewDelete().
    104 			TableExpr("? AS ?", bun.Ident("status_to_emojis"), bun.Ident("status_to_emoji")).
    105 			Where("? = ?", bun.Ident("status_to_emoji.emoji_id"), id).
    106 			Exec(ctx); err != nil {
    107 			return err
    108 		}
    109 
    110 		// delete links between this emoji and any accounts that use it
    111 		if _, err := tx.
    112 			NewDelete().
    113 			TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
    114 			Where("? = ?", bun.Ident("account_to_emoji.emoji_id"), id).
    115 			Exec(ctx); err != nil {
    116 			return err
    117 		}
    118 
    119 		if _, err := tx.
    120 			NewDelete().
    121 			TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
    122 			Where("? = ?", bun.Ident("emoji.id"), id).
    123 			Exec(ctx); err != nil {
    124 			return e.conn.ProcessError(err)
    125 		}
    126 
    127 		return nil
    128 	})
    129 }
    130 
    131 func (e *emojiDB) GetEmojis(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, db.Error) {
    132 	emojiIDs := []string{}
    133 
    134 	subQuery := e.conn.
    135 		NewSelect().
    136 		ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids"))
    137 
    138 	// To ensure consistent ordering and make paging possible, we sort not by shortcode
    139 	// but by [shortcode]@[domain]. Because sqlite and postgres have different syntax
    140 	// for concatenation, that means we need to switch here. Depending on which driver
    141 	// is in use, query will look something like this (sqlite):
    142 	//
    143 	//	SELECT
    144 	//		"emoji"."id" AS "emoji_ids",
    145 	//		lower("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain"
    146 	//	FROM
    147 	//		"emojis" AS "emoji"
    148 	//	ORDER BY
    149 	//		"shortcode_domain" ASC
    150 	//
    151 	// Or like this (postgres):
    152 	//
    153 	//	SELECT
    154 	//		"emoji"."id" AS "emoji_ids",
    155 	//		LOWER(CONCAT("emoji"."shortcode", '@', COALESCE("emoji"."domain", ''))) AS "shortcode_domain"
    156 	//	FROM
    157 	//		"emojis" AS "emoji"
    158 	//	ORDER BY
    159 	//		"shortcode_domain" ASC
    160 	switch e.conn.Dialect().Name() {
    161 	case dialect.SQLite:
    162 		subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain"))
    163 	case dialect.PG:
    164 		subQuery = subQuery.ColumnExpr("LOWER(CONCAT(?, ?, COALESCE(?, ?))) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain"))
    165 	default:
    166 		panic("db conn was neither pg not sqlite")
    167 	}
    168 
    169 	subQuery = subQuery.TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji"))
    170 
    171 	if domain == "" {
    172 		subQuery = subQuery.Where("? IS NULL", bun.Ident("emoji.domain"))
    173 	} else if domain != db.EmojiAllDomains {
    174 		subQuery = subQuery.Where("? = ?", bun.Ident("emoji.domain"), domain)
    175 	}
    176 
    177 	switch {
    178 	case includeDisabled && !includeEnabled:
    179 		// show only disabled emojis
    180 		subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), true)
    181 	case includeEnabled && !includeDisabled:
    182 		// show only enabled emojis
    183 		subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), false)
    184 	default:
    185 		// show emojis regardless of emoji.disabled value
    186 	}
    187 
    188 	if shortcode != "" {
    189 		subQuery = subQuery.Where("LOWER(?) = LOWER(?)", bun.Ident("emoji.shortcode"), shortcode)
    190 	}
    191 
    192 	// assume we want to sort ASC (a-z) unless informed otherwise
    193 	order := "ASC"
    194 
    195 	if maxShortcodeDomain != "" {
    196 		subQuery = subQuery.Where("? > LOWER(?)", bun.Ident("shortcode_domain"), maxShortcodeDomain)
    197 	}
    198 
    199 	if minShortcodeDomain != "" {
    200 		subQuery = subQuery.Where("? < LOWER(?)", bun.Ident("shortcode_domain"), minShortcodeDomain)
    201 		// if we have a minShortcodeDomain we're paging upwards/backwards
    202 		order = "DESC"
    203 	}
    204 
    205 	subQuery = subQuery.Order("shortcode_domain " + order)
    206 
    207 	if limit > 0 {
    208 		subQuery = subQuery.Limit(limit)
    209 	}
    210 
    211 	// Wrap the subQuery in a query, since we don't need to select the shortcode_domain column.
    212 	//
    213 	// The final query will come out looking something like...
    214 	//
    215 	//	SELECT
    216 	//		"subquery"."emoji_ids"
    217 	//	FROM (
    218 	//		SELECT
    219 	//			"emoji"."id" AS "emoji_ids",
    220 	//			LOWER("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain"
    221 	//		FROM
    222 	//			"emojis" AS "emoji"
    223 	//		ORDER BY
    224 	//			"shortcode_domain" ASC
    225 	//	) AS "subquery"
    226 	if err := e.conn.
    227 		NewSelect().
    228 		Column("subquery.emoji_ids").
    229 		TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")).
    230 		Scan(ctx, &emojiIDs); err != nil {
    231 		return nil, e.conn.ProcessError(err)
    232 	}
    233 
    234 	if order == "DESC" {
    235 		// Reverse the slice order so the caller still
    236 		// gets emojis in expected a-z alphabetical order.
    237 		//
    238 		// See https://github.com/golang/go/wiki/SliceTricks#reversing
    239 		for i := len(emojiIDs)/2 - 1; i >= 0; i-- {
    240 			opp := len(emojiIDs) - 1 - i
    241 			emojiIDs[i], emojiIDs[opp] = emojiIDs[opp], emojiIDs[i]
    242 		}
    243 	}
    244 
    245 	return e.GetEmojisByIDs(ctx, emojiIDs)
    246 }
    247 
    248 func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) {
    249 	emojiIDs := []string{}
    250 
    251 	q := e.conn.
    252 		NewSelect().
    253 		TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
    254 		Column("emoji.id").
    255 		Where("? = ?", bun.Ident("emoji.visible_in_picker"), true).
    256 		Where("? = ?", bun.Ident("emoji.disabled"), false).
    257 		Where("? IS NULL", bun.Ident("emoji.domain")).
    258 		Order("emoji.shortcode ASC")
    259 
    260 	if err := q.Scan(ctx, &emojiIDs); err != nil {
    261 		return nil, e.conn.ProcessError(err)
    262 	}
    263 
    264 	return e.GetEmojisByIDs(ctx, emojiIDs)
    265 }
    266 
    267 func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) {
    268 	return e.getEmoji(
    269 		ctx,
    270 		"ID",
    271 		func(emoji *gtsmodel.Emoji) error {
    272 			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
    273 		},
    274 		id,
    275 	)
    276 }
    277 
    278 func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) {
    279 	return e.getEmoji(
    280 		ctx,
    281 		"URI",
    282 		func(emoji *gtsmodel.Emoji) error {
    283 			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
    284 		},
    285 		uri,
    286 	)
    287 }
    288 
    289 func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) {
    290 	return e.getEmoji(
    291 		ctx,
    292 		"Shortcode.Domain",
    293 		func(emoji *gtsmodel.Emoji) error {
    294 			q := e.newEmojiQ(emoji)
    295 
    296 			if domain != "" {
    297 				q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode)
    298 				q = q.Where("? = ?", bun.Ident("emoji.domain"), domain)
    299 			} else {
    300 				q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode))
    301 				q = q.Where("? IS NULL", bun.Ident("emoji.domain"))
    302 			}
    303 
    304 			return q.Scan(ctx)
    305 		},
    306 		shortcode,
    307 		domain,
    308 	)
    309 }
    310 
    311 func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) {
    312 	return e.getEmoji(
    313 		ctx,
    314 		"ImageStaticURL",
    315 		func(emoji *gtsmodel.Emoji) error {
    316 			return e.
    317 				newEmojiQ(emoji).
    318 				Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL).
    319 				Scan(ctx)
    320 		},
    321 		imageStaticURL,
    322 	)
    323 }
    324 
    325 func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error {
    326 	return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error {
    327 		_, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx)
    328 		return e.conn.ProcessError(err)
    329 	})
    330 }
    331 
    332 func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) {
    333 	emojiCategoryIDs := []string{}
    334 
    335 	q := e.conn.
    336 		NewSelect().
    337 		TableExpr("? AS ?", bun.Ident("emoji_categories"), bun.Ident("emoji_category")).
    338 		Column("emoji_category.id").
    339 		Order("emoji_category.name ASC")
    340 
    341 	if err := q.Scan(ctx, &emojiCategoryIDs); err != nil {
    342 		return nil, e.conn.ProcessError(err)
    343 	}
    344 
    345 	return e.GetEmojiCategoriesByIDs(ctx, emojiCategoryIDs)
    346 }
    347 
    348 func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) {
    349 	return e.getEmojiCategory(
    350 		ctx,
    351 		"ID",
    352 		func(emojiCategory *gtsmodel.EmojiCategory) error {
    353 			return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx)
    354 		},
    355 		id,
    356 	)
    357 }
    358 
    359 func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) {
    360 	return e.getEmojiCategory(
    361 		ctx,
    362 		"Name",
    363 		func(emojiCategory *gtsmodel.EmojiCategory) error {
    364 			return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx)
    365 		},
    366 		name,
    367 	)
    368 }
    369 
    370 func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) {
    371 	return e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) {
    372 		var emoji gtsmodel.Emoji
    373 
    374 		// Not cached! Perform database query
    375 		if err := dbQuery(&emoji); err != nil {
    376 			return nil, e.conn.ProcessError(err)
    377 		}
    378 
    379 		return &emoji, nil
    380 	}, keyParts...)
    381 }
    382 
    383 func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) {
    384 	if len(emojiIDs) == 0 {
    385 		return nil, db.ErrNoEntries
    386 	}
    387 
    388 	emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs))
    389 
    390 	for _, id := range emojiIDs {
    391 		emoji, err := e.GetEmojiByID(ctx, id)
    392 		if err != nil {
    393 			log.Errorf(ctx, "emojisFromIDs: error getting emoji %q: %v", id, err)
    394 			continue
    395 		}
    396 
    397 		emojis = append(emojis, emoji)
    398 	}
    399 
    400 	return emojis, nil
    401 }
    402 
    403 func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) {
    404 	return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) {
    405 		var category gtsmodel.EmojiCategory
    406 
    407 		// Not cached! Perform database query
    408 		if err := dbQuery(&category); err != nil {
    409 			return nil, e.conn.ProcessError(err)
    410 		}
    411 
    412 		return &category, nil
    413 	}, keyParts...)
    414 }
    415 
    416 func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) {
    417 	if len(emojiCategoryIDs) == 0 {
    418 		return nil, db.ErrNoEntries
    419 	}
    420 
    421 	emojiCategories := make([]*gtsmodel.EmojiCategory, 0, len(emojiCategoryIDs))
    422 
    423 	for _, id := range emojiCategoryIDs {
    424 		emojiCategory, err := e.GetEmojiCategory(ctx, id)
    425 		if err != nil {
    426 			log.Errorf(ctx, "error getting emoji category %q: %v", id, err)
    427 			continue
    428 		}
    429 
    430 		emojiCategories = append(emojiCategories, emojiCategory)
    431 	}
    432 
    433 	return emojiCategories, nil
    434 }