gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

manager.go (13844B)


      1 package manage
      2 
      3 import (
      4 	"context"
      5 	"time"
      6 
      7 	"github.com/superseriousbusiness/oauth2/v4"
      8 	"github.com/superseriousbusiness/oauth2/v4/errors"
      9 	"github.com/superseriousbusiness/oauth2/v4/generates"
     10 	"github.com/superseriousbusiness/oauth2/v4/models"
     11 )
     12 
     13 // NewDefaultManager create to default authorization management instance
     14 func NewDefaultManager() *Manager {
     15 	m := NewManager()
     16 	// default implementation
     17 	m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
     18 	m.MapAccessGenerate(generates.NewAccessGenerate())
     19 
     20 	return m
     21 }
     22 
     23 // NewManager create to authorization management instance
     24 func NewManager() *Manager {
     25 	return &Manager{
     26 		gtcfg:       make(map[oauth2.GrantType]*Config),
     27 		validateURI: DefaultValidateURI,
     28 	}
     29 }
     30 
     31 // Manager provide authorization management
     32 type Manager struct {
     33 	codeExp           time.Duration
     34 	gtcfg             map[oauth2.GrantType]*Config
     35 	rcfg              *RefreshingConfig
     36 	validateURI       ValidateURIHandler
     37 	authorizeGenerate oauth2.AuthorizeGenerate
     38 	accessGenerate    oauth2.AccessGenerate
     39 	tokenStore        oauth2.TokenStore
     40 	clientStore       oauth2.ClientStore
     41 }
     42 
     43 // get grant type config
     44 func (m *Manager) grantConfig(gt oauth2.GrantType) *Config {
     45 	if c, ok := m.gtcfg[gt]; ok && c != nil {
     46 		return c
     47 	}
     48 	switch gt {
     49 	case oauth2.AuthorizationCode:
     50 		return DefaultAuthorizeCodeTokenCfg
     51 	case oauth2.Implicit:
     52 		return DefaultImplicitTokenCfg
     53 	case oauth2.PasswordCredentials:
     54 		return DefaultPasswordTokenCfg
     55 	case oauth2.ClientCredentials:
     56 		return DefaultClientTokenCfg
     57 	}
     58 	return &Config{}
     59 }
     60 
     61 // SetAuthorizeCodeExp set the authorization code expiration time
     62 func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) {
     63 	m.codeExp = exp
     64 }
     65 
     66 // SetAuthorizeCodeTokenCfg set the authorization code grant token config
     67 func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) {
     68 	m.gtcfg[oauth2.AuthorizationCode] = cfg
     69 }
     70 
     71 // SetImplicitTokenCfg set the implicit grant token config
     72 func (m *Manager) SetImplicitTokenCfg(cfg *Config) {
     73 	m.gtcfg[oauth2.Implicit] = cfg
     74 }
     75 
     76 // SetPasswordTokenCfg set the password grant token config
     77 func (m *Manager) SetPasswordTokenCfg(cfg *Config) {
     78 	m.gtcfg[oauth2.PasswordCredentials] = cfg
     79 }
     80 
     81 // SetClientTokenCfg set the client grant token config
     82 func (m *Manager) SetClientTokenCfg(cfg *Config) {
     83 	m.gtcfg[oauth2.ClientCredentials] = cfg
     84 }
     85 
     86 // SetRefreshTokenCfg set the refreshing token config
     87 func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) {
     88 	m.rcfg = cfg
     89 }
     90 
     91 // SetValidateURIHandler set the validates that RedirectURI is contained in baseURI
     92 func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) {
     93 	m.validateURI = handler
     94 }
     95 
     96 // MapAuthorizeGenerate mapping the authorize code generate interface
     97 func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) {
     98 	m.authorizeGenerate = gen
     99 }
    100 
    101 // MapAccessGenerate mapping the access token generate interface
    102 func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) {
    103 	m.accessGenerate = gen
    104 }
    105 
    106 // MapClientStorage mapping the client store interface
    107 func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
    108 	m.clientStore = stor
    109 }
    110 
    111 // MustClientStorage mandatory mapping the client store interface
    112 func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
    113 	if err != nil {
    114 		panic(err.Error())
    115 	}
    116 	m.clientStore = stor
    117 }
    118 
    119 // MapTokenStorage mapping the token store interface
    120 func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
    121 	m.tokenStore = stor
    122 }
    123 
    124 // MustTokenStorage mandatory mapping the token store interface
    125 func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
    126 	if err != nil {
    127 		panic(err)
    128 	}
    129 	m.tokenStore = stor
    130 }
    131 
    132 // GetClient get the client information
    133 func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) {
    134 	cli, err = m.clientStore.GetByID(ctx, clientID)
    135 	if err != nil {
    136 		return
    137 	} else if cli == nil {
    138 		err = errors.ErrInvalidClient
    139 	}
    140 	return
    141 }
    142 
    143 // GenerateAuthToken generate the authorization token(code)
    144 func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
    145 	cli, err := m.GetClient(ctx, tgr.ClientID)
    146 	if err != nil {
    147 		return nil, err
    148 	} else if tgr.RedirectURI != "" {
    149 		if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
    150 			return nil, err
    151 		}
    152 	}
    153 
    154 	ti := models.NewToken()
    155 	ti.SetClientID(tgr.ClientID)
    156 	ti.SetUserID(tgr.UserID)
    157 	ti.SetRedirectURI(tgr.RedirectURI)
    158 	ti.SetScope(tgr.Scope)
    159 
    160 	createAt := time.Now()
    161 	td := &oauth2.GenerateBasic{
    162 		Client:    cli,
    163 		UserID:    tgr.UserID,
    164 		CreateAt:  createAt,
    165 		TokenInfo: ti,
    166 		Request:   tgr.Request,
    167 	}
    168 	switch rt {
    169 	case oauth2.Code:
    170 		codeExp := m.codeExp
    171 		if codeExp == 0 {
    172 			codeExp = DefaultCodeExp
    173 		}
    174 		ti.SetCodeCreateAt(createAt)
    175 		ti.SetCodeExpiresIn(codeExp)
    176 		if exp := tgr.AccessTokenExp; exp > 0 {
    177 			ti.SetAccessExpiresIn(exp)
    178 		}
    179 		if tgr.CodeChallenge != "" {
    180 			ti.SetCodeChallenge(tgr.CodeChallenge)
    181 			ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
    182 		}
    183 
    184 		tv, err := m.authorizeGenerate.Token(ctx, td)
    185 		if err != nil {
    186 			return nil, err
    187 		}
    188 		ti.SetCode(tv)
    189 	case oauth2.Token:
    190 		// set access token expires
    191 		icfg := m.grantConfig(oauth2.Implicit)
    192 		aexp := icfg.AccessTokenExp
    193 		if exp := tgr.AccessTokenExp; exp > 0 {
    194 			aexp = exp
    195 		}
    196 		ti.SetAccessCreateAt(createAt)
    197 		ti.SetAccessExpiresIn(aexp)
    198 
    199 		if icfg.IsGenerateRefresh {
    200 			ti.SetRefreshCreateAt(createAt)
    201 			ti.SetRefreshExpiresIn(icfg.RefreshTokenExp)
    202 		}
    203 
    204 		tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh)
    205 		if err != nil {
    206 			return nil, err
    207 		}
    208 		ti.SetAccess(tv)
    209 
    210 		if rv != "" {
    211 			ti.SetRefresh(rv)
    212 		}
    213 	}
    214 
    215 	err = m.tokenStore.Create(ctx, ti)
    216 	if err != nil {
    217 		return nil, err
    218 	}
    219 	return ti, nil
    220 }
    221 
    222 // get authorization code data
    223 func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
    224 	ti, err := m.tokenStore.GetByCode(ctx, code)
    225 	if err != nil {
    226 		return nil, err
    227 	} else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) {
    228 		err = errors.ErrInvalidAuthorizeCode
    229 		return nil, errors.ErrInvalidAuthorizeCode
    230 	}
    231 	return ti, nil
    232 }
    233 
    234 // delete authorization code data
    235 func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error {
    236 	return m.tokenStore.RemoveByCode(ctx, code)
    237 }
    238 
    239 // get and delete authorization code data
    240 func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
    241 	code := tgr.Code
    242 	ti, err := m.getAuthorizationCode(ctx, code)
    243 	if err != nil {
    244 		return nil, err
    245 	} else if ti.GetClientID() != tgr.ClientID {
    246 		return nil, errors.ErrInvalidAuthorizeCode
    247 	} else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI {
    248 		return nil, errors.ErrInvalidAuthorizeCode
    249 	}
    250 
    251 	err = m.delAuthorizationCode(ctx, code)
    252 	if err != nil {
    253 		return nil, err
    254 	}
    255 	return ti, nil
    256 }
    257 
    258 func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
    259 	cc := ti.GetCodeChallenge()
    260 	// early return
    261 	if cc == "" && ver == "" {
    262 		return nil
    263 	}
    264 	if cc == "" {
    265 		return errors.ErrMissingCodeVerifier
    266 	}
    267 	if ver == "" {
    268 		return errors.ErrMissingCodeVerifier
    269 	}
    270 	ccm := ti.GetCodeChallengeMethod()
    271 	if ccm.String() == "" {
    272 		ccm = oauth2.CodeChallengePlain
    273 	}
    274 	if !ccm.Validate(cc, ver) {
    275 		return errors.ErrInvalidCodeChallenge
    276 	}
    277 	return nil
    278 }
    279 
    280 // GenerateAccessToken generate the access token
    281 func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
    282 	cli, err := m.GetClient(ctx, tgr.ClientID)
    283 	if err != nil {
    284 		return nil, err
    285 	}
    286 	if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
    287 		if !cliPass.VerifyPassword(tgr.ClientSecret) {
    288 			return nil, errors.ErrInvalidClient
    289 		}
    290 	} else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() {
    291 		return nil, errors.ErrInvalidClient
    292 	}
    293 	if tgr.RedirectURI != "" {
    294 		if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
    295 			return nil, err
    296 		}
    297 	}
    298 
    299 	if gt == oauth2.AuthorizationCode {
    300 		ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
    301 		if err != nil {
    302 			return nil, err
    303 		}
    304 		if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
    305 			return nil, err
    306 		}
    307 		tgr.UserID = ti.GetUserID()
    308 		tgr.Scope = ti.GetScope()
    309 		if exp := ti.GetAccessExpiresIn(); exp > 0 {
    310 			tgr.AccessTokenExp = exp
    311 		}
    312 	}
    313 
    314 	ti := models.NewToken()
    315 	ti.SetClientID(tgr.ClientID)
    316 	ti.SetUserID(tgr.UserID)
    317 	ti.SetRedirectURI(tgr.RedirectURI)
    318 	ti.SetScope(tgr.Scope)
    319 
    320 	createAt := time.Now()
    321 	ti.SetAccessCreateAt(createAt)
    322 
    323 	// set access token expires
    324 	gcfg := m.grantConfig(gt)
    325 	aexp := gcfg.AccessTokenExp
    326 	if exp := tgr.AccessTokenExp; exp > 0 {
    327 		aexp = exp
    328 	}
    329 	ti.SetAccessExpiresIn(aexp)
    330 	if gcfg.IsGenerateRefresh {
    331 		ti.SetRefreshCreateAt(createAt)
    332 		ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp)
    333 	}
    334 
    335 	td := &oauth2.GenerateBasic{
    336 		Client:    cli,
    337 		UserID:    tgr.UserID,
    338 		CreateAt:  createAt,
    339 		TokenInfo: ti,
    340 		Request:   tgr.Request,
    341 	}
    342 
    343 	av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh)
    344 	if err != nil {
    345 		return nil, err
    346 	}
    347 	ti.SetAccess(av)
    348 
    349 	if rv != "" {
    350 		ti.SetRefresh(rv)
    351 	}
    352 
    353 	err = m.tokenStore.Create(ctx, ti)
    354 	if err != nil {
    355 		return nil, err
    356 	}
    357 
    358 	return ti, nil
    359 }
    360 
    361 // RefreshAccessToken refreshing an access token
    362 func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
    363 	cli, err := m.GetClient(ctx, tgr.ClientID)
    364 	if err != nil {
    365 		return nil, err
    366 	} else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
    367 		if !cliPass.VerifyPassword(tgr.ClientSecret) {
    368 			return nil, errors.ErrInvalidClient
    369 		}
    370 	} else if tgr.ClientSecret != cli.GetSecret() {
    371 		return nil, errors.ErrInvalidClient
    372 	}
    373 
    374 	ti, err := m.LoadRefreshToken(ctx, tgr.Refresh)
    375 	if err != nil {
    376 		return nil, err
    377 	} else if ti.GetClientID() != tgr.ClientID {
    378 		return nil, errors.ErrInvalidRefreshToken
    379 	}
    380 
    381 	oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
    382 
    383 	td := &oauth2.GenerateBasic{
    384 		Client:    cli,
    385 		UserID:    ti.GetUserID(),
    386 		CreateAt:  time.Now(),
    387 		TokenInfo: ti,
    388 		Request:   tgr.Request,
    389 	}
    390 
    391 	rcfg := DefaultRefreshTokenCfg
    392 	if v := m.rcfg; v != nil {
    393 		rcfg = v
    394 	}
    395 
    396 	ti.SetAccessCreateAt(td.CreateAt)
    397 	if v := rcfg.AccessTokenExp; v > 0 {
    398 		ti.SetAccessExpiresIn(v)
    399 	}
    400 
    401 	if v := rcfg.RefreshTokenExp; v > 0 {
    402 		ti.SetRefreshExpiresIn(v)
    403 	}
    404 
    405 	if rcfg.IsResetRefreshTime {
    406 		ti.SetRefreshCreateAt(td.CreateAt)
    407 	}
    408 
    409 	if scope := tgr.Scope; scope != "" {
    410 		ti.SetScope(scope)
    411 	}
    412 
    413 	tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh)
    414 	if err != nil {
    415 		return nil, err
    416 	}
    417 
    418 	ti.SetAccess(tv)
    419 	if rv != "" {
    420 		ti.SetRefresh(rv)
    421 	}
    422 
    423 	if err := m.tokenStore.Create(ctx, ti); err != nil {
    424 		return nil, err
    425 	}
    426 
    427 	if rcfg.IsRemoveAccess {
    428 		// remove the old access token
    429 		if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil {
    430 			return nil, err
    431 		}
    432 	}
    433 
    434 	if rcfg.IsRemoveRefreshing && rv != "" {
    435 		// remove the old refresh token
    436 		if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil {
    437 			return nil, err
    438 		}
    439 	}
    440 
    441 	if rv == "" {
    442 		ti.SetRefresh("")
    443 		ti.SetRefreshCreateAt(time.Now())
    444 		ti.SetRefreshExpiresIn(0)
    445 	}
    446 
    447 	return ti, nil
    448 }
    449 
    450 // RemoveAccessToken use the access token to delete the token information
    451 func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error {
    452 	if access == "" {
    453 		return errors.ErrInvalidAccessToken
    454 	}
    455 	return m.tokenStore.RemoveByAccess(ctx, access)
    456 }
    457 
    458 // RemoveRefreshToken use the refresh token to delete the token information
    459 func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error {
    460 	if refresh == "" {
    461 		return errors.ErrInvalidAccessToken
    462 	}
    463 	return m.tokenStore.RemoveByRefresh(ctx, refresh)
    464 }
    465 
    466 // LoadAccessToken according to the access token for corresponding token information
    467 func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) {
    468 	if access == "" {
    469 		return nil, errors.ErrInvalidAccessToken
    470 	}
    471 
    472 	ct := time.Now()
    473 	ti, err := m.tokenStore.GetByAccess(ctx, access)
    474 	if err != nil {
    475 		return nil, err
    476 	} else if ti == nil || ti.GetAccess() != access {
    477 		return nil, errors.ErrInvalidAccessToken
    478 	} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
    479 		ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
    480 		return nil, errors.ErrExpiredRefreshToken
    481 	} else if ti.GetAccessExpiresIn() != 0 &&
    482 		ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
    483 		return nil, errors.ErrExpiredAccessToken
    484 	}
    485 	return ti, nil
    486 }
    487 
    488 // LoadRefreshToken according to the refresh token for corresponding token information
    489 func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
    490 	if refresh == "" {
    491 		return nil, errors.ErrInvalidRefreshToken
    492 	}
    493 
    494 	ti, err := m.tokenStore.GetByRefresh(ctx, refresh)
    495 	if err != nil {
    496 		return nil, err
    497 	} else if ti == nil || ti.GetRefresh() != refresh {
    498 		return nil, errors.ErrInvalidRefreshToken
    499 	} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
    500 		ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
    501 		return nil, errors.ErrExpiredRefreshToken
    502 	}
    503 	return ti, nil
    504 }