gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

tokenstore.go (8567B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package oauth
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"fmt"
     24 	"time"
     25 
     26 	"github.com/superseriousbusiness/gotosocial/internal/db"
     27 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
     28 	"github.com/superseriousbusiness/gotosocial/internal/id"
     29 	"github.com/superseriousbusiness/gotosocial/internal/log"
     30 	"github.com/superseriousbusiness/oauth2/v4"
     31 	"github.com/superseriousbusiness/oauth2/v4/models"
     32 )
     33 
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
	// Embedded interface value. newTokenStore never assigns it, so it stays nil.
	// NOTE(review): presumably embedded only so the type satisfies the interface;
	// any interface method that is not overridden on *tokenStore would be dispatched
	// to this nil value and panic — confirm every method is overridden.
	oauth2.TokenStore
	// db is the storage backend used to read, write and delete tokens.
	db db.Basic
}
     39 
     40 // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
     41 //
     42 // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
     43 // the tokens in the DB once per minute and deletes any that have expired.
     44 func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore {
     45 	ts := &tokenStore{
     46 		db: db,
     47 	}
     48 
     49 	// set the token store to clean out expired tokens once per minute, or return if we're done
     50 	go func(ctx context.Context, ts *tokenStore) {
     51 	cleanloop:
     52 		for {
     53 			select {
     54 			case <-ctx.Done():
     55 				log.Info(ctx, "breaking cleanloop")
     56 				break cleanloop
     57 			case <-time.After(1 * time.Minute):
     58 				log.Trace(ctx, "sweeping out old oauth entries broom broom")
     59 				if err := ts.sweep(ctx); err != nil {
     60 					log.Errorf(ctx, "error while sweeping oauth entries: %s", err)
     61 				}
     62 			}
     63 		}
     64 	}(ctx, ts)
     65 	return ts
     66 }
     67 
     68 // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
     69 func (ts *tokenStore) sweep(ctx context.Context) error {
     70 	// select *all* tokens from the db
     71 	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
     72 	tokens := new([]*gtsmodel.Token)
     73 	if err := ts.db.GetAll(ctx, tokens); err != nil {
     74 		return err
     75 	}
     76 
     77 	// iterate through and remove expired tokens
     78 	now := time.Now()
     79 	for _, dbt := range *tokens {
     80 		// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
     81 		// we only want to check if a token expired before now if the expiry time is *not zero*;
     82 		// ie., if it's been explicity set.
     83 		if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
     84 			if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
     85 				return err
     86 			}
     87 		}
     88 	}
     89 
     90 	return nil
     91 }
     92 
     93 // Create creates and store the new token information.
     94 // For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34
     95 func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
     96 	t, ok := info.(*models.Token)
     97 	if !ok {
     98 		return errors.New("info param was not a models.Token")
     99 	}
    100 
    101 	dbt := TokenToDBToken(t)
    102 	if dbt.ID == "" {
    103 		dbtID, err := id.NewRandomULID()
    104 		if err != nil {
    105 			return err
    106 		}
    107 		dbt.ID = dbtID
    108 	}
    109 
    110 	if err := ts.db.Put(ctx, dbt); err != nil {
    111 		return fmt.Errorf("error in tokenstore create: %s", err)
    112 	}
    113 	return nil
    114 }
    115 
    116 // RemoveByCode deletes a token from the DB based on the Code field
    117 func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
    118 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &gtsmodel.Token{})
    119 }
    120 
    121 // RemoveByAccess deletes a token from the DB based on the Access field
    122 func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
    123 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &gtsmodel.Token{})
    124 }
    125 
    126 // RemoveByRefresh deletes a token from the DB based on the Refresh field
    127 func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
    128 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &gtsmodel.Token{})
    129 }
    130 
    131 // GetByCode selects a token from the DB based on the Code field
    132 func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
    133 	if code == "" {
    134 		return nil, nil
    135 	}
    136 	dbt := &gtsmodel.Token{
    137 		Code: code,
    138 	}
    139 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil {
    140 		return nil, err
    141 	}
    142 	return DBTokenToToken(dbt), nil
    143 }
    144 
    145 // GetByAccess selects a token from the DB based on the Access field
    146 func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
    147 	if access == "" {
    148 		return nil, nil
    149 	}
    150 	dbt := &gtsmodel.Token{
    151 		Access: access,
    152 	}
    153 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil {
    154 		return nil, err
    155 	}
    156 	return DBTokenToToken(dbt), nil
    157 }
    158 
    159 // GetByRefresh selects a token from the DB based on the Refresh field
    160 func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
    161 	if refresh == "" {
    162 		return nil, nil
    163 	}
    164 	dbt := &gtsmodel.Token{
    165 		Refresh: refresh,
    166 	}
    167 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil {
    168 		return nil, err
    169 	}
    170 	return DBTokenToToken(dbt), nil
    171 }
    172 
    173 /*
    174 	The following functions are helpers for the token store implementation; they should only be used internally.
    175 */
    176 
    177 // TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database.
    178 func TokenToDBToken(tkn *models.Token) *gtsmodel.Token {
    179 	now := time.Now()
    180 
    181 	// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
    182 	// going to cause all sorts of interesting problems. So check first to make sure that the ExpiresIn is not equal
    183 	// to the zero value of a time.Duration, which is 0s. If it *is* empty/nil, just leave the ExpiresAt at nil as well.
    184 
    185 	cea := time.Time{}
    186 	if tkn.CodeExpiresIn != 0*time.Second {
    187 		cea = now.Add(tkn.CodeExpiresIn)
    188 	}
    189 
    190 	aea := time.Time{}
    191 	if tkn.AccessExpiresIn != 0*time.Second {
    192 		aea = now.Add(tkn.AccessExpiresIn)
    193 	}
    194 
    195 	rea := time.Time{}
    196 	if tkn.RefreshExpiresIn != 0*time.Second {
    197 		rea = now.Add(tkn.RefreshExpiresIn)
    198 	}
    199 
    200 	return &gtsmodel.Token{
    201 		ClientID:            tkn.ClientID,
    202 		UserID:              tkn.UserID,
    203 		RedirectURI:         tkn.RedirectURI,
    204 		Scope:               tkn.Scope,
    205 		Code:                tkn.Code,
    206 		CodeChallenge:       tkn.CodeChallenge,
    207 		CodeChallengeMethod: tkn.CodeChallengeMethod,
    208 		CodeCreateAt:        tkn.CodeCreateAt,
    209 		CodeExpiresAt:       cea,
    210 		Access:              tkn.Access,
    211 		AccessCreateAt:      tkn.AccessCreateAt,
    212 		AccessExpiresAt:     aea,
    213 		Refresh:             tkn.Refresh,
    214 		RefreshCreateAt:     tkn.RefreshCreateAt,
    215 		RefreshExpiresAt:    rea,
    216 	}
    217 }
    218 
    219 // DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token
    220 func DBTokenToToken(dbt *gtsmodel.Token) *models.Token {
    221 	now := time.Now()
    222 
    223 	var codeExpiresIn time.Duration
    224 	if !dbt.CodeExpiresAt.IsZero() {
    225 		codeExpiresIn = dbt.CodeExpiresAt.Sub(now)
    226 	}
    227 
    228 	var accessExpiresIn time.Duration
    229 	if !dbt.AccessExpiresAt.IsZero() {
    230 		accessExpiresIn = dbt.AccessExpiresAt.Sub(now)
    231 	}
    232 
    233 	var refreshExpiresIn time.Duration
    234 	if !dbt.RefreshExpiresAt.IsZero() {
    235 		refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now)
    236 	}
    237 
    238 	return &models.Token{
    239 		ClientID:            dbt.ClientID,
    240 		UserID:              dbt.UserID,
    241 		RedirectURI:         dbt.RedirectURI,
    242 		Scope:               dbt.Scope,
    243 		Code:                dbt.Code,
    244 		CodeChallenge:       dbt.CodeChallenge,
    245 		CodeChallengeMethod: dbt.CodeChallengeMethod,
    246 		CodeCreateAt:        dbt.CodeCreateAt,
    247 		CodeExpiresIn:       codeExpiresIn,
    248 		Access:              dbt.Access,
    249 		AccessCreateAt:      dbt.AccessCreateAt,
    250 		AccessExpiresIn:     accessExpiresIn,
    251 		Refresh:             dbt.Refresh,
    252 		RefreshCreateAt:     dbt.RefreshCreateAt,
    253 		RefreshExpiresIn:    refreshExpiresIn,
    254 	}
    255 }