server.go (15161B)
1 package server 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net/http" 8 "net/url" 9 "strings" 10 "time" 11 12 "github.com/superseriousbusiness/oauth2/v4" 13 "github.com/superseriousbusiness/oauth2/v4/errors" 14 ) 15 16 // NewDefaultServer create a default authorization server 17 func NewDefaultServer(manager oauth2.Manager) *Server { 18 return NewServer(NewConfig(), manager) 19 } 20 21 // NewServer create authorization server 22 func NewServer(cfg *Config, manager oauth2.Manager) *Server { 23 srv := &Server{ 24 Config: cfg, 25 Manager: manager, 26 } 27 28 // default handler 29 srv.ClientInfoHandler = ClientBasicHandler 30 31 srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { 32 return "", errors.ErrAccessDenied 33 } 34 35 srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { 36 return "", errors.ErrAccessDenied 37 } 38 return srv 39 } 40 41 // Server Provide authorization server 42 type Server struct { 43 Config *Config 44 Manager oauth2.Manager 45 ClientInfoHandler ClientInfoHandler 46 ClientAuthorizedHandler ClientAuthorizedHandler 47 ClientScopeHandler ClientScopeHandler 48 UserAuthorizationHandler UserAuthorizationHandler 49 PasswordAuthorizationHandler PasswordAuthorizationHandler 50 RefreshingValidationHandler RefreshingValidationHandler 51 RefreshingScopeHandler RefreshingScopeHandler 52 ResponseErrorHandler ResponseErrorHandler 53 InternalErrorHandler InternalErrorHandler 54 ExtensionFieldsHandler ExtensionFieldsHandler 55 AccessTokenExpHandler AccessTokenExpHandler 56 AuthorizeScopeHandler AuthorizeScopeHandler 57 } 58 59 func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { 60 if req == nil { 61 return err 62 } 63 data, _, _ := s.GetErrorData(err) 64 return s.redirect(w, req, data) 65 } 66 67 func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { 68 uri, err := s.GetRedirectURI(req, data) 69 if err != nil { 70 return err 71 } 72 73 w.Header().Set("Location", uri) 74 w.WriteHeader(302) 75 return nil 76 } 77 78 func (s *Server) tokenError(w http.ResponseWriter, err error) error { 79 data, statusCode, header := s.GetErrorData(err) 80 return s.token(w, data, header, statusCode) 81 } 82 83 func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { 84 w.Header().Set("Content-Type", "application/json;charset=UTF-8") 85 w.Header().Set("Cache-Control", "no-store") 86 w.Header().Set("Pragma", "no-cache") 87 88 for key := range header { 89 w.Header().Set(key, header.Get(key)) 90 } 91 92 status := http.StatusOK 93 if len(statusCode) > 0 && statusCode[0] > 0 { 94 status = statusCode[0] 95 } 96 97 w.WriteHeader(status) 98 return json.NewEncoder(w).Encode(data) 99 } 100 101 // GetRedirectURI get redirect uri 102 func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { 103 u, err := url.Parse(req.RedirectURI) 104 if err != nil { 105 return "", err 106 } 107 108 q := u.Query() 109 if req.State != "" { 110 q.Set("state", req.State) 111 } 112 113 for k, v := range data { 114 q.Set(k, fmt.Sprint(v)) 115 } 116 117 switch req.ResponseType { 118 case oauth2.Code: 119 u.RawQuery = q.Encode() 120 case oauth2.Token: 121 u.RawQuery = "" 122 fragment, err := url.QueryUnescape(q.Encode()) 123 if err != nil { 124 return "", err 125 } 126 u.Fragment = fragment 127 } 128 129 return u.String(), nil 130 } 131 132 // CheckResponseType check allows response type 133 func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { 134 for _, art := range s.Config.AllowedResponseTypes { 135 if art == rt { 136 return true 137 } 138 } 139 return false 140 } 141 142 // CheckCodeChallengeMethod checks for allowed code challenge method 143 func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { 144 for _, c := range s.Config.AllowedCodeChallengeMethods { 145 if c == ccm { 146 return true 147 } 148 } 149 return false 150 } 151 152 // ValidationAuthorizeRequest the authorization request validation 153 func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { 154 redirectURI := r.FormValue("redirect_uri") 155 clientID := r.FormValue("client_id") 156 if !(r.Method == "GET" || r.Method == "POST") || 157 clientID == "" { 158 return nil, errors.ErrInvalidRequest 159 } 160 161 resType := oauth2.ResponseType(r.FormValue("response_type")) 162 if resType.String() == "" { 163 return nil, errors.ErrUnsupportedResponseType 164 } else if allowed := s.CheckResponseType(resType); !allowed { 165 return nil, errors.ErrUnauthorizedClient 166 } 167 168 cc := r.FormValue("code_challenge") 169 if cc == "" && s.Config.ForcePKCE { 170 return nil, errors.ErrCodeChallengeRquired 171 } 172 if cc != "" && (len(cc) < 43 || len(cc) > 128) { 173 return nil, errors.ErrInvalidCodeChallengeLen 174 } 175 176 ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) 177 // set default 178 if ccm == "" { 179 ccm = oauth2.CodeChallengePlain 180 } 181 if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { 182 return nil, errors.ErrUnsupportedCodeChallengeMethod 183 } 184 185 req := &AuthorizeRequest{ 186 RedirectURI: redirectURI, 187 ResponseType: resType, 188 ClientID: clientID, 189 State: r.FormValue("state"), 190 Scope: r.FormValue("scope"), 191 Request: r, 192 CodeChallenge: cc, 193 CodeChallengeMethod: ccm, 194 } 195 return req, nil 196 } 197 198 // GetAuthorizeToken get authorization token(code) 199 func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { 200 // check the client allows the grant type 201 if fn := s.ClientAuthorizedHandler; fn != nil { 202 gt := oauth2.AuthorizationCode 203 if req.ResponseType == oauth2.Token { 204 gt = oauth2.Implicit 205 } 206 207 allowed, err := fn(req.ClientID, gt) 208 if err != nil { 209 return nil, err 210 } else if !allowed { 211 return nil, errors.ErrUnauthorizedClient 212 } 213 } 214 215 tgr := &oauth2.TokenGenerateRequest{ 216 ClientID: req.ClientID, 217 UserID: req.UserID, 218 RedirectURI: req.RedirectURI, 219 Scope: req.Scope, 220 AccessTokenExp: req.AccessTokenExp, 221 Request: req.Request, 222 } 223 224 // check the client allows the authorized scope 225 if fn := s.ClientScopeHandler; fn != nil { 226 allowed, err := fn(tgr) 227 if err != nil { 228 return nil, err 229 } else if !allowed { 230 return nil, errors.ErrInvalidScope 231 } 232 } 233 234 tgr.CodeChallenge = req.CodeChallenge 235 tgr.CodeChallengeMethod = req.CodeChallengeMethod 236 237 return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) 238 } 239 240 // GetAuthorizeData get authorization response data 241 func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { 242 if rt == oauth2.Code { 243 return map[string]interface{}{ 244 "code": ti.GetCode(), 245 } 246 } 247 return s.GetTokenData(ti) 248 } 249 250 // HandleAuthorizeRequest the authorization request handling 251 func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { 252 ctx := r.Context() 253 254 req, err := s.ValidationAuthorizeRequest(r) 255 if err != nil { 256 return s.redirectError(w, req, err) 257 } 258 259 // user authorization 260 userID, err := s.UserAuthorizationHandler(w, r) 261 if err != nil { 262 return s.redirectError(w, req, err) 263 } else if userID == "" { 264 return nil 265 } 266 req.UserID = userID 267 268 // specify the scope of authorization 269 if fn := s.AuthorizeScopeHandler; fn != nil { 270 scope, err := fn(w, r) 271 if err != nil { 272 return err 273 } else if scope != "" { 274 req.Scope = scope 275 } 276 } 277 278 // specify the expiration time of access token 279 if fn := s.AccessTokenExpHandler; fn != nil { 280 exp, err := fn(w, r) 281 if err != nil { 282 return err 283 } 284 req.AccessTokenExp = exp 285 } 286 287 ti, err := s.GetAuthorizeToken(ctx, req) 288 if err != nil { 289 return s.redirectError(w, req, err) 290 } 291 292 // If the redirect URI is empty, the default domain provided by the client is used. 293 if req.RedirectURI == "" { 294 client, err := s.Manager.GetClient(ctx, req.ClientID) 295 if err != nil { 296 return err 297 } 298 req.RedirectURI = client.GetDomain() 299 } 300 301 return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) 302 } 303 304 // ValidationTokenRequest the token request validation 305 func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { 306 if v := r.Method; !(v == "POST" || 307 (s.Config.AllowGetAccessRequest && v == "GET")) { 308 return "", nil, errors.ErrInvalidRequest 309 } 310 311 gt := oauth2.GrantType(r.FormValue("grant_type")) 312 if gt.String() == "" { 313 return "", nil, errors.ErrUnsupportedGrantType 314 } 315 316 if !s.CheckGrantType(gt) { 317 return "", nil, errors.ErrUnsupportedGrantType 318 } 319 320 clientID, clientSecret, err := s.ClientInfoHandler(r) 321 if err != nil { 322 return "", nil, err 323 } 324 325 tgr := &oauth2.TokenGenerateRequest{ 326 ClientID: clientID, 327 ClientSecret: clientSecret, 328 Request: r, 329 } 330 331 switch gt { 332 case oauth2.AuthorizationCode: 333 tgr.RedirectURI = r.FormValue("redirect_uri") 334 tgr.Code = r.FormValue("code") 335 if tgr.RedirectURI == "" || 336 tgr.Code == "" { 337 return "", nil, errors.ErrInvalidRequest 338 } 339 tgr.CodeVerifier = r.FormValue("code_verifier") 340 if s.Config.ForcePKCE && tgr.CodeVerifier == "" { 341 return "", nil, errors.ErrInvalidRequest 342 } 343 case oauth2.PasswordCredentials: 344 tgr.Scope = r.FormValue("scope") 345 username, password := r.FormValue("username"), r.FormValue("password") 346 if username == "" || password == "" { 347 return "", nil, errors.ErrInvalidRequest 348 } 349 350 userID, err := s.PasswordAuthorizationHandler(username, password) 351 if err != nil { 352 return "", nil, err 353 } else if userID == "" { 354 return "", nil, errors.ErrInvalidGrant 355 } 356 tgr.UserID = userID 357 case oauth2.ClientCredentials: 358 tgr.Scope = r.FormValue("scope") 359 tgr.RedirectURI = r.FormValue("redirect_uri") 360 case oauth2.Refreshing: 361 tgr.Refresh = r.FormValue("refresh_token") 362 tgr.Scope = r.FormValue("scope") 363 if tgr.Refresh == "" { 364 return "", nil, errors.ErrInvalidRequest 365 } 366 } 367 return gt, tgr, nil 368 } 369 370 // CheckGrantType check allows grant type 371 func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { 372 for _, agt := range s.Config.AllowedGrantTypes { 373 if agt == gt { 374 return true 375 } 376 } 377 return false 378 } 379 380 // GetAccessToken access token 381 func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, 382 error) { 383 if allowed := s.CheckGrantType(gt); !allowed { 384 return nil, errors.ErrUnauthorizedClient 385 } 386 387 if fn := s.ClientAuthorizedHandler; fn != nil { 388 allowed, err := fn(tgr.ClientID, gt) 389 if err != nil { 390 return nil, err 391 } else if !allowed { 392 return nil, errors.ErrUnauthorizedClient 393 } 394 } 395 396 switch gt { 397 case oauth2.AuthorizationCode: 398 ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) 399 if err != nil { 400 switch err { 401 case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: 402 return nil, errors.ErrInvalidGrant 403 case errors.ErrInvalidClient: 404 return nil, errors.ErrInvalidClient 405 default: 406 return nil, err 407 } 408 } 409 return ti, nil 410 case oauth2.PasswordCredentials, oauth2.ClientCredentials: 411 if fn := s.ClientScopeHandler; fn != nil { 412 allowed, err := fn(tgr) 413 if err != nil { 414 return nil, err 415 } else if !allowed { 416 return nil, errors.ErrInvalidScope 417 } 418 } 419 return s.Manager.GenerateAccessToken(ctx, gt, tgr) 420 case oauth2.Refreshing: 421 // check scope 422 if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { 423 rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) 424 if err != nil { 425 if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { 426 return nil, errors.ErrInvalidGrant 427 } 428 return nil, err 429 } 430 431 allowed, err := scopeFn(tgr, rti.GetScope()) 432 if err != nil { 433 return nil, err 434 } else if !allowed { 435 return nil, errors.ErrInvalidScope 436 } 437 } 438 439 if validationFn := s.RefreshingValidationHandler; validationFn != nil { 440 rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) 441 if err != nil { 442 if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { 443 return nil, errors.ErrInvalidGrant 444 } 445 return nil, err 446 } 447 allowed, err := validationFn(rti) 448 if err != nil { 449 return nil, err 450 } else if !allowed { 451 return nil, errors.ErrInvalidScope 452 } 453 } 454 455 ti, err := s.Manager.RefreshAccessToken(ctx, tgr) 456 if err != nil { 457 if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { 458 return nil, errors.ErrInvalidGrant 459 } 460 return nil, err 461 } 462 return ti, nil 463 } 464 465 return nil, errors.ErrUnsupportedGrantType 466 } 467 468 // GetTokenData token data 469 func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { 470 data := map[string]interface{}{ 471 "access_token": ti.GetAccess(), 472 "token_type": s.Config.TokenType, 473 "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), 474 } 475 476 if scope := ti.GetScope(); scope != "" { 477 data["scope"] = scope 478 } 479 480 if refresh := ti.GetRefresh(); refresh != "" { 481 data["refresh_token"] = refresh 482 } 483 484 if fn := s.ExtensionFieldsHandler; fn != nil { 485 ext := fn(ti) 486 for k, v := range ext { 487 if _, ok := data[k]; ok { 488 continue 489 } 490 data[k] = v 491 } 492 } 493 return data 494 } 495 496 // HandleTokenRequest token request handling 497 func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { 498 ctx := r.Context() 499 500 gt, tgr, err := s.ValidationTokenRequest(r) 501 if err != nil { 502 return s.tokenError(w, err) 503 } 504 505 ti, err := s.GetAccessToken(ctx, gt, tgr) 506 if err != nil { 507 return s.tokenError(w, err) 508 } 509 510 return s.token(w, s.GetTokenData(ti), nil) 511 } 512 513 // GetErrorData get error response data 514 func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { 515 var re errors.Response 516 if v, ok := errors.Descriptions[err]; ok { 517 re.Error = err 518 re.Description = v 519 re.StatusCode = errors.StatusCodes[err] 520 } else { 521 if fn := s.InternalErrorHandler; fn != nil { 522 if v := fn(err); v != nil { 523 re = *v 524 } 525 } 526 527 if re.Error == nil { 528 re.Error = errors.ErrServerError 529 re.Description = errors.Descriptions[errors.ErrServerError] 530 re.StatusCode = errors.StatusCodes[errors.ErrServerError] 531 } 532 } 533 534 if fn := s.ResponseErrorHandler; fn != nil { 535 fn(&re) 536 } 537 538 data := make(map[string]interface{}) 539 if err := re.Error; err != nil { 540 data["error"] = err.Error() 541 } 542 543 if v := re.ErrorCode; v != 0 { 544 data["error_code"] = v 545 } 546 547 if v := re.Description; v != "" { 548 data["error_description"] = v 549 } 550 551 if v := re.URI; v != "" { 552 data["error_uri"] = v 553 } 554 555 statusCode := http.StatusInternalServerError 556 if v := re.StatusCode; v > 0 { 557 statusCode = v 558 } 559 560 return data, statusCode, re.Header 561 } 562 563 // BearerAuth parse bearer token 564 func (s *Server) BearerAuth(r *http.Request) (string, bool) { 565 auth := r.Header.Get("Authorization") 566 prefix := "Bearer " 567 token := "" 568 569 if auth != "" && strings.HasPrefix(auth, prefix) { 570 token = auth[len(prefix):] 571 } else { 572 token = r.FormValue("access_token") 573 } 574 575 return token, token != "" 576 } 577 578 // ValidationBearerToken validation the bearer tokens 579 // https://tools.ietf.org/html/rfc6750 580 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { 581 ctx := r.Context() 582 583 accessToken, ok := s.BearerAuth(r) 584 if !ok { 585 return nil, errors.ErrInvalidAccessToken 586 } 587 588 return s.Manager.LoadAccessToken(ctx, accessToken) 589 }