manager.go (13844B)
1 package manage 2 3 import ( 4 "context" 5 "time" 6 7 "github.com/superseriousbusiness/oauth2/v4" 8 "github.com/superseriousbusiness/oauth2/v4/errors" 9 "github.com/superseriousbusiness/oauth2/v4/generates" 10 "github.com/superseriousbusiness/oauth2/v4/models" 11 ) 12 13 // NewDefaultManager create to default authorization management instance 14 func NewDefaultManager() *Manager { 15 m := NewManager() 16 // default implementation 17 m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) 18 m.MapAccessGenerate(generates.NewAccessGenerate()) 19 20 return m 21 } 22 23 // NewManager create to authorization management instance 24 func NewManager() *Manager { 25 return &Manager{ 26 gtcfg: make(map[oauth2.GrantType]*Config), 27 validateURI: DefaultValidateURI, 28 } 29 } 30 31 // Manager provide authorization management 32 type Manager struct { 33 codeExp time.Duration 34 gtcfg map[oauth2.GrantType]*Config 35 rcfg *RefreshingConfig 36 validateURI ValidateURIHandler 37 authorizeGenerate oauth2.AuthorizeGenerate 38 accessGenerate oauth2.AccessGenerate 39 tokenStore oauth2.TokenStore 40 clientStore oauth2.ClientStore 41 } 42 43 // get grant type config 44 func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { 45 if c, ok := m.gtcfg[gt]; ok && c != nil { 46 return c 47 } 48 switch gt { 49 case oauth2.AuthorizationCode: 50 return DefaultAuthorizeCodeTokenCfg 51 case oauth2.Implicit: 52 return DefaultImplicitTokenCfg 53 case oauth2.PasswordCredentials: 54 return DefaultPasswordTokenCfg 55 case oauth2.ClientCredentials: 56 return DefaultClientTokenCfg 57 } 58 return &Config{} 59 } 60 61 // SetAuthorizeCodeExp set the authorization code expiration time 62 func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { 63 m.codeExp = exp 64 } 65 66 // SetAuthorizeCodeTokenCfg set the authorization code grant token config 67 func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { 68 m.gtcfg[oauth2.AuthorizationCode] = cfg 69 } 70 71 // SetImplicitTokenCfg set the implicit grant token config 72 func (m *Manager) SetImplicitTokenCfg(cfg *Config) { 73 m.gtcfg[oauth2.Implicit] = cfg 74 } 75 76 // SetPasswordTokenCfg set the password grant token config 77 func (m *Manager) SetPasswordTokenCfg(cfg *Config) { 78 m.gtcfg[oauth2.PasswordCredentials] = cfg 79 } 80 81 // SetClientTokenCfg set the client grant token config 82 func (m *Manager) SetClientTokenCfg(cfg *Config) { 83 m.gtcfg[oauth2.ClientCredentials] = cfg 84 } 85 86 // SetRefreshTokenCfg set the refreshing token config 87 func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { 88 m.rcfg = cfg 89 } 90 91 // SetValidateURIHandler set the validates that RedirectURI is contained in baseURI 92 func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { 93 m.validateURI = handler 94 } 95 96 // MapAuthorizeGenerate mapping the authorize code generate interface 97 func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { 98 m.authorizeGenerate = gen 99 } 100 101 // MapAccessGenerate mapping the access token generate interface 102 func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { 103 m.accessGenerate = gen 104 } 105 106 // MapClientStorage mapping the client store interface 107 func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { 108 m.clientStore = stor 109 } 110 111 // MustClientStorage mandatory mapping the client store interface 112 func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { 113 if err != nil { 114 panic(err.Error()) 115 } 116 m.clientStore = stor 117 } 118 119 // MapTokenStorage mapping the token store interface 120 func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { 121 m.tokenStore = stor 122 } 123 124 // MustTokenStorage mandatory mapping the token store interface 125 func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { 126 if err != nil { 127 panic(err) 128 } 129 m.tokenStore = stor 130 } 131 132 // GetClient get the client information 133 func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { 134 cli, err = m.clientStore.GetByID(ctx, clientID) 135 if err != nil { 136 return 137 } else if cli == nil { 138 err = errors.ErrInvalidClient 139 } 140 return 141 } 142 143 // GenerateAuthToken generate the authorization token(code) 144 func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 145 cli, err := m.GetClient(ctx, tgr.ClientID) 146 if err != nil { 147 return nil, err 148 } else if tgr.RedirectURI != "" { 149 if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { 150 return nil, err 151 } 152 } 153 154 ti := models.NewToken() 155 ti.SetClientID(tgr.ClientID) 156 ti.SetUserID(tgr.UserID) 157 ti.SetRedirectURI(tgr.RedirectURI) 158 ti.SetScope(tgr.Scope) 159 160 createAt := time.Now() 161 td := &oauth2.GenerateBasic{ 162 Client: cli, 163 UserID: tgr.UserID, 164 CreateAt: createAt, 165 TokenInfo: ti, 166 Request: tgr.Request, 167 } 168 switch rt { 169 case oauth2.Code: 170 codeExp := m.codeExp 171 if codeExp == 0 { 172 codeExp = DefaultCodeExp 173 } 174 ti.SetCodeCreateAt(createAt) 175 ti.SetCodeExpiresIn(codeExp) 176 if exp := tgr.AccessTokenExp; exp > 0 { 177 ti.SetAccessExpiresIn(exp) 178 } 179 if tgr.CodeChallenge != "" { 180 ti.SetCodeChallenge(tgr.CodeChallenge) 181 ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod) 182 } 183 184 tv, err := m.authorizeGenerate.Token(ctx, td) 185 if err != nil { 186 return nil, err 187 } 188 ti.SetCode(tv) 189 case oauth2.Token: 190 // set access token expires 191 icfg := m.grantConfig(oauth2.Implicit) 192 aexp := icfg.AccessTokenExp 193 if exp := tgr.AccessTokenExp; exp > 0 { 194 aexp = exp 195 } 196 ti.SetAccessCreateAt(createAt) 197 ti.SetAccessExpiresIn(aexp) 198 199 if icfg.IsGenerateRefresh { 200 ti.SetRefreshCreateAt(createAt) 201 ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) 202 } 203 204 tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) 205 if err != nil { 206 return nil, err 207 } 208 ti.SetAccess(tv) 209 210 if rv != "" { 211 ti.SetRefresh(rv) 212 } 213 } 214 215 err = m.tokenStore.Create(ctx, ti) 216 if err != nil { 217 return nil, err 218 } 219 return ti, nil 220 } 221 222 // get authorization code data 223 func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { 224 ti, err := m.tokenStore.GetByCode(ctx, code) 225 if err != nil { 226 return nil, err 227 } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { 228 err = errors.ErrInvalidAuthorizeCode 229 return nil, errors.ErrInvalidAuthorizeCode 230 } 231 return ti, nil 232 } 233 234 // delete authorization code data 235 func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { 236 return m.tokenStore.RemoveByCode(ctx, code) 237 } 238 239 // get and delete authorization code data 240 func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 241 code := tgr.Code 242 ti, err := m.getAuthorizationCode(ctx, code) 243 if err != nil { 244 return nil, err 245 } else if ti.GetClientID() != tgr.ClientID { 246 return nil, errors.ErrInvalidAuthorizeCode 247 } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { 248 return nil, errors.ErrInvalidAuthorizeCode 249 } 250 251 err = m.delAuthorizationCode(ctx, code) 252 if err != nil { 253 return nil, err 254 } 255 return ti, nil 256 } 257 258 func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error { 259 cc := ti.GetCodeChallenge() 260 // early return 261 if cc == "" && ver == "" { 262 return nil 263 } 264 if cc == "" { 265 return errors.ErrMissingCodeVerifier 266 } 267 if ver == "" { 268 return errors.ErrMissingCodeVerifier 269 } 270 ccm := ti.GetCodeChallengeMethod() 271 if ccm.String() == "" { 272 ccm = oauth2.CodeChallengePlain 273 } 274 if !ccm.Validate(cc, ver) { 275 return errors.ErrInvalidCodeChallenge 276 } 277 return nil 278 } 279 280 // GenerateAccessToken generate the access token 281 func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 282 cli, err := m.GetClient(ctx, tgr.ClientID) 283 if err != nil { 284 return nil, err 285 } 286 if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { 287 if !cliPass.VerifyPassword(tgr.ClientSecret) { 288 return nil, errors.ErrInvalidClient 289 } 290 } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { 291 return nil, errors.ErrInvalidClient 292 } 293 if tgr.RedirectURI != "" { 294 if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { 295 return nil, err 296 } 297 } 298 299 if gt == oauth2.AuthorizationCode { 300 ti, err := m.getAndDelAuthorizationCode(ctx, tgr) 301 if err != nil { 302 return nil, err 303 } 304 if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil { 305 return nil, err 306 } 307 tgr.UserID = ti.GetUserID() 308 tgr.Scope = ti.GetScope() 309 if exp := ti.GetAccessExpiresIn(); exp > 0 { 310 tgr.AccessTokenExp = exp 311 } 312 } 313 314 ti := models.NewToken() 315 ti.SetClientID(tgr.ClientID) 316 ti.SetUserID(tgr.UserID) 317 ti.SetRedirectURI(tgr.RedirectURI) 318 ti.SetScope(tgr.Scope) 319 320 createAt := time.Now() 321 ti.SetAccessCreateAt(createAt) 322 323 // set access token expires 324 gcfg := m.grantConfig(gt) 325 aexp := gcfg.AccessTokenExp 326 if exp := tgr.AccessTokenExp; exp > 0 { 327 aexp = exp 328 } 329 ti.SetAccessExpiresIn(aexp) 330 if gcfg.IsGenerateRefresh { 331 ti.SetRefreshCreateAt(createAt) 332 ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) 333 } 334 335 td := &oauth2.GenerateBasic{ 336 Client: cli, 337 UserID: tgr.UserID, 338 CreateAt: createAt, 339 TokenInfo: ti, 340 Request: tgr.Request, 341 } 342 343 av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) 344 if err != nil { 345 return nil, err 346 } 347 ti.SetAccess(av) 348 349 if rv != "" { 350 ti.SetRefresh(rv) 351 } 352 353 err = m.tokenStore.Create(ctx, ti) 354 if err != nil { 355 return nil, err 356 } 357 358 return ti, nil 359 } 360 361 // RefreshAccessToken refreshing an access token 362 func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 363 cli, err := m.GetClient(ctx, tgr.ClientID) 364 if err != nil { 365 return nil, err 366 } else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { 367 if !cliPass.VerifyPassword(tgr.ClientSecret) { 368 return nil, errors.ErrInvalidClient 369 } 370 } else if tgr.ClientSecret != cli.GetSecret() { 371 return nil, errors.ErrInvalidClient 372 } 373 374 ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) 375 if err != nil { 376 return nil, err 377 } else if ti.GetClientID() != tgr.ClientID { 378 return nil, errors.ErrInvalidRefreshToken 379 } 380 381 oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() 382 383 td := &oauth2.GenerateBasic{ 384 Client: cli, 385 UserID: ti.GetUserID(), 386 CreateAt: time.Now(), 387 TokenInfo: ti, 388 Request: tgr.Request, 389 } 390 391 rcfg := DefaultRefreshTokenCfg 392 if v := m.rcfg; v != nil { 393 rcfg = v 394 } 395 396 ti.SetAccessCreateAt(td.CreateAt) 397 if v := rcfg.AccessTokenExp; v > 0 { 398 ti.SetAccessExpiresIn(v) 399 } 400 401 if v := rcfg.RefreshTokenExp; v > 0 { 402 ti.SetRefreshExpiresIn(v) 403 } 404 405 if rcfg.IsResetRefreshTime { 406 ti.SetRefreshCreateAt(td.CreateAt) 407 } 408 409 if scope := tgr.Scope; scope != "" { 410 ti.SetScope(scope) 411 } 412 413 tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) 414 if err != nil { 415 return nil, err 416 } 417 418 ti.SetAccess(tv) 419 if rv != "" { 420 ti.SetRefresh(rv) 421 } 422 423 if err := m.tokenStore.Create(ctx, ti); err != nil { 424 return nil, err 425 } 426 427 if rcfg.IsRemoveAccess { 428 // remove the old access token 429 if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { 430 return nil, err 431 } 432 } 433 434 if rcfg.IsRemoveRefreshing && rv != "" { 435 // remove the old refresh token 436 if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { 437 return nil, err 438 } 439 } 440 441 if rv == "" { 442 ti.SetRefresh("") 443 ti.SetRefreshCreateAt(time.Now()) 444 ti.SetRefreshExpiresIn(0) 445 } 446 447 return ti, nil 448 } 449 450 // RemoveAccessToken use the access token to delete the token information 451 func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { 452 if access == "" { 453 return errors.ErrInvalidAccessToken 454 } 455 return m.tokenStore.RemoveByAccess(ctx, access) 456 } 457 458 // RemoveRefreshToken use the refresh token to delete the token information 459 func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { 460 if refresh == "" { 461 return errors.ErrInvalidAccessToken 462 } 463 return m.tokenStore.RemoveByRefresh(ctx, refresh) 464 } 465 466 // LoadAccessToken according to the access token for corresponding token information 467 func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { 468 if access == "" { 469 return nil, errors.ErrInvalidAccessToken 470 } 471 472 ct := time.Now() 473 ti, err := m.tokenStore.GetByAccess(ctx, access) 474 if err != nil { 475 return nil, err 476 } else if ti == nil || ti.GetAccess() != access { 477 return nil, errors.ErrInvalidAccessToken 478 } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && 479 ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { 480 return nil, errors.ErrExpiredRefreshToken 481 } else if ti.GetAccessExpiresIn() != 0 && 482 ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { 483 return nil, errors.ErrExpiredAccessToken 484 } 485 return ti, nil 486 } 487 488 // LoadRefreshToken according to the refresh token for corresponding token information 489 func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { 490 if refresh == "" { 491 return nil, errors.ErrInvalidRefreshToken 492 } 493 494 ti, err := m.tokenStore.GetByRefresh(ctx, refresh) 495 if err != nil { 496 return nil, err 497 } else if ti == nil || ti.GetRefresh() != refresh { 498 return nil, errors.ErrInvalidRefreshToken 499 } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire 500 ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { 501 return nil, errors.ErrExpiredRefreshToken 502 } 503 return ti, nil 504 }