tokenstore.go (8567B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package oauth 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "time" 25 26 "github.com/superseriousbusiness/gotosocial/internal/db" 27 "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" 28 "github.com/superseriousbusiness/gotosocial/internal/id" 29 "github.com/superseriousbusiness/gotosocial/internal/log" 30 "github.com/superseriousbusiness/oauth2/v4" 31 "github.com/superseriousbusiness/oauth2/v4/models" 32 ) 33 34 // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. 35 type tokenStore struct { 36 oauth2.TokenStore 37 db db.Basic 38 } 39 40 // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. 41 // 42 // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through 43 // the tokens in the DB once per minute and deletes any that have expired. 44 func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { 45 ts := &tokenStore{ 46 db: db, 47 } 48 49 // set the token store to clean out expired tokens once per minute, or return if we're done 50 go func(ctx context.Context, ts *tokenStore) { 51 cleanloop: 52 for { 53 select { 54 case <-ctx.Done(): 55 log.Info(ctx, "breaking cleanloop") 56 break cleanloop 57 case <-time.After(1 * time.Minute): 58 log.Trace(ctx, "sweeping out old oauth entries broom broom") 59 if err := ts.sweep(ctx); err != nil { 60 log.Errorf(ctx, "error while sweeping oauth entries: %s", err) 61 } 62 } 63 } 64 }(ctx, ts) 65 return ts 66 } 67 68 // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. 69 func (ts *tokenStore) sweep(ctx context.Context) error { 70 // select *all* tokens from the db 71 // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. 72 tokens := new([]*gtsmodel.Token) 73 if err := ts.db.GetAll(ctx, tokens); err != nil { 74 return err 75 } 76 77 // iterate through and remove expired tokens 78 now := time.Now() 79 for _, dbt := range *tokens { 80 // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: 81 // we only want to check if a token expired before now if the expiry time is *not zero*; 82 // ie., if it's been explicity set. 83 if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { 84 if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { 85 return err 86 } 87 } 88 } 89 90 return nil 91 } 92 93 // Create creates and store the new token information. 94 // For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34 95 func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { 96 t, ok := info.(*models.Token) 97 if !ok { 98 return errors.New("info param was not a models.Token") 99 } 100 101 dbt := TokenToDBToken(t) 102 if dbt.ID == "" { 103 dbtID, err := id.NewRandomULID() 104 if err != nil { 105 return err 106 } 107 dbt.ID = dbtID 108 } 109 110 if err := ts.db.Put(ctx, dbt); err != nil { 111 return fmt.Errorf("error in tokenstore create: %s", err) 112 } 113 return nil 114 } 115 116 // RemoveByCode deletes a token from the DB based on the Code field 117 func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { 118 return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, >smodel.Token{}) 119 } 120 121 // RemoveByAccess deletes a token from the DB based on the Access field 122 func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { 123 return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, >smodel.Token{}) 124 } 125 126 // RemoveByRefresh deletes a token from the DB based on the Refresh field 127 func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { 128 return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, >smodel.Token{}) 129 } 130 131 // GetByCode selects a token from the DB based on the Code field 132 func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { 133 if code == "" { 134 return nil, nil 135 } 136 dbt := >smodel.Token{ 137 Code: code, 138 } 139 if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { 140 return nil, err 141 } 142 return DBTokenToToken(dbt), nil 143 } 144 145 // GetByAccess selects a token from the DB based on the Access field 146 func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { 147 if access == "" { 148 return nil, nil 149 } 150 dbt := >smodel.Token{ 151 Access: access, 152 } 153 if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { 154 return nil, err 155 } 156 return DBTokenToToken(dbt), nil 157 } 158 159 // GetByRefresh selects a token from the DB based on the Refresh field 160 func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { 161 if refresh == "" { 162 return nil, nil 163 } 164 dbt := >smodel.Token{ 165 Refresh: refresh, 166 } 167 if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { 168 return nil, err 169 } 170 return DBTokenToToken(dbt), nil 171 } 172 173 /* 174 The following models are basically helpers for the token store implementation, they should only be used internally. 175 */ 176 177 // TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database. 178 func TokenToDBToken(tkn *models.Token) *gtsmodel.Token { 179 now := time.Now() 180 181 // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's 182 // going to cause all sorts of interesting problems. So check first to make sure that the ExpiresIn is not equal 183 // to the zero value of a time.Duration, which is 0s. If it *is* empty/nil, just leave the ExpiresAt at nil as well. 184 185 cea := time.Time{} 186 if tkn.CodeExpiresIn != 0*time.Second { 187 cea = now.Add(tkn.CodeExpiresIn) 188 } 189 190 aea := time.Time{} 191 if tkn.AccessExpiresIn != 0*time.Second { 192 aea = now.Add(tkn.AccessExpiresIn) 193 } 194 195 rea := time.Time{} 196 if tkn.RefreshExpiresIn != 0*time.Second { 197 rea = now.Add(tkn.RefreshExpiresIn) 198 } 199 200 return >smodel.Token{ 201 ClientID: tkn.ClientID, 202 UserID: tkn.UserID, 203 RedirectURI: tkn.RedirectURI, 204 Scope: tkn.Scope, 205 Code: tkn.Code, 206 CodeChallenge: tkn.CodeChallenge, 207 CodeChallengeMethod: tkn.CodeChallengeMethod, 208 CodeCreateAt: tkn.CodeCreateAt, 209 CodeExpiresAt: cea, 210 Access: tkn.Access, 211 AccessCreateAt: tkn.AccessCreateAt, 212 AccessExpiresAt: aea, 213 Refresh: tkn.Refresh, 214 RefreshCreateAt: tkn.RefreshCreateAt, 215 RefreshExpiresAt: rea, 216 } 217 } 218 219 // DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token 220 func DBTokenToToken(dbt *gtsmodel.Token) *models.Token { 221 now := time.Now() 222 223 var codeExpiresIn time.Duration 224 if !dbt.CodeExpiresAt.IsZero() { 225 codeExpiresIn = dbt.CodeExpiresAt.Sub(now) 226 } 227 228 var accessExpiresIn time.Duration 229 if !dbt.AccessExpiresAt.IsZero() { 230 accessExpiresIn = dbt.AccessExpiresAt.Sub(now) 231 } 232 233 var refreshExpiresIn time.Duration 234 if !dbt.RefreshExpiresAt.IsZero() { 235 refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now) 236 } 237 238 return &models.Token{ 239 ClientID: dbt.ClientID, 240 UserID: dbt.UserID, 241 RedirectURI: dbt.RedirectURI, 242 Scope: dbt.Scope, 243 Code: dbt.Code, 244 CodeChallenge: dbt.CodeChallenge, 245 CodeChallengeMethod: dbt.CodeChallengeMethod, 246 CodeCreateAt: dbt.CodeCreateAt, 247 CodeExpiresIn: codeExpiresIn, 248 Access: dbt.Access, 249 AccessCreateAt: dbt.AccessCreateAt, 250 AccessExpiresIn: accessExpiresIn, 251 Refresh: dbt.Refresh, 252 RefreshCreateAt: dbt.RefreshCreateAt, 253 RefreshExpiresIn: refreshExpiresIn, 254 } 255 }