gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

server.go (15161B)


      1 package server
      2 
      3 import (
      4 	"context"
      5 	"encoding/json"
      6 	"fmt"
      7 	"net/http"
      8 	"net/url"
      9 	"strings"
     10 	"time"
     11 
     12 	"github.com/superseriousbusiness/oauth2/v4"
     13 	"github.com/superseriousbusiness/oauth2/v4/errors"
     14 )
     15 
     16 // NewDefaultServer create a default authorization server
     17 func NewDefaultServer(manager oauth2.Manager) *Server {
     18 	return NewServer(NewConfig(), manager)
     19 }
     20 
     21 // NewServer create authorization server
     22 func NewServer(cfg *Config, manager oauth2.Manager) *Server {
     23 	srv := &Server{
     24 		Config:  cfg,
     25 		Manager: manager,
     26 	}
     27 
     28 	// default handler
     29 	srv.ClientInfoHandler = ClientBasicHandler
     30 
     31 	srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
     32 		return "", errors.ErrAccessDenied
     33 	}
     34 
     35 	srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
     36 		return "", errors.ErrAccessDenied
     37 	}
     38 	return srv
     39 }
     40 
// Server is an OAuth 2.0 authorization server. It validates authorization and
// token requests, delegates token generation and lookup to Manager, and can be
// customized through the handler fields below (nil handlers are skipped unless
// noted otherwise).
type Server struct {
	// Config supplies the allowed response types, grant types and code
	// challenge methods, PKCE enforcement, whether GET token requests are
	// accepted, and the token_type reported in token responses.
	Config *Config

	// Manager generates, loads and refreshes tokens and looks up clients.
	Manager oauth2.Manager

	// ClientInfoHandler extracts the client ID and secret from a token
	// request. NewServer sets it to ClientBasicHandler; it is called
	// unconditionally by ValidationTokenRequest.
	ClientInfoHandler ClientInfoHandler

	// ClientAuthorizedHandler, if set, reports whether a client may use a
	// grant type. It is consulted for authorize requests (authorization code
	// or implicit grant) and for token requests; false yields
	// errors.ErrUnauthorizedClient.
	ClientAuthorizedHandler ClientAuthorizedHandler

	// ClientScopeHandler, if set, reports whether the scope in a token
	// generate request is allowed. It is consulted for authorize requests and
	// for password and client-credentials token requests; false yields
	// errors.ErrInvalidScope.
	ClientScopeHandler ClientScopeHandler

	// UserAuthorizationHandler returns the ID of the user approving an
	// authorize request. An empty ID with a nil error ends
	// HandleAuthorizeRequest without redirecting. NewServer installs a
	// handler that always returns errors.ErrAccessDenied.
	UserAuthorizationHandler UserAuthorizationHandler

	// PasswordAuthorizationHandler maps password-grant credentials to a user
	// ID; an empty ID yields errors.ErrInvalidGrant. NewServer installs a
	// handler that always returns errors.ErrAccessDenied.
	PasswordAuthorizationHandler PasswordAuthorizationHandler

	// RefreshingValidationHandler, if set, validates the stored refresh token
	// info during a refresh; false yields errors.ErrInvalidScope.
	RefreshingValidationHandler RefreshingValidationHandler

	// RefreshingScopeHandler, if set, is consulted when a refresh request
	// names a scope, receiving the refresh token's stored scope; false
	// yields errors.ErrInvalidScope.
	RefreshingScopeHandler RefreshingScopeHandler

	// ResponseErrorHandler, if set, receives the error response built by
	// GetErrorData and may modify it before it is serialized.
	ResponseErrorHandler ResponseErrorHandler

	// InternalErrorHandler, if set, converts errors without a standard
	// description into a response. If it returns nil or a response without
	// an Error, errors.ErrServerError is used instead.
	InternalErrorHandler InternalErrorHandler

	// ExtensionFieldsHandler, if set, supplies extra fields for token
	// responses; fields that collide with ones already present are ignored.
	ExtensionFieldsHandler ExtensionFieldsHandler

	// AccessTokenExpHandler, if set, provides the access token expiry for an
	// authorize request.
	AccessTokenExpHandler AccessTokenExpHandler

	// AuthorizeScopeHandler, if set, may override the scope of an authorize
	// request; an empty result keeps the requested scope.
	AuthorizeScopeHandler AuthorizeScopeHandler
}
     58 
     59 func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
     60 	if req == nil {
     61 		return err
     62 	}
     63 	data, _, _ := s.GetErrorData(err)
     64 	return s.redirect(w, req, data)
     65 }
     66 
     67 func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
     68 	uri, err := s.GetRedirectURI(req, data)
     69 	if err != nil {
     70 		return err
     71 	}
     72 
     73 	w.Header().Set("Location", uri)
     74 	w.WriteHeader(302)
     75 	return nil
     76 }
     77 
     78 func (s *Server) tokenError(w http.ResponseWriter, err error) error {
     79 	data, statusCode, header := s.GetErrorData(err)
     80 	return s.token(w, data, header, statusCode)
     81 }
     82 
     83 func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
     84 	w.Header().Set("Content-Type", "application/json;charset=UTF-8")
     85 	w.Header().Set("Cache-Control", "no-store")
     86 	w.Header().Set("Pragma", "no-cache")
     87 
     88 	for key := range header {
     89 		w.Header().Set(key, header.Get(key))
     90 	}
     91 
     92 	status := http.StatusOK
     93 	if len(statusCode) > 0 && statusCode[0] > 0 {
     94 		status = statusCode[0]
     95 	}
     96 
     97 	w.WriteHeader(status)
     98 	return json.NewEncoder(w).Encode(data)
     99 }
    100 
    101 // GetRedirectURI get redirect uri
    102 func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
    103 	u, err := url.Parse(req.RedirectURI)
    104 	if err != nil {
    105 		return "", err
    106 	}
    107 
    108 	q := u.Query()
    109 	if req.State != "" {
    110 		q.Set("state", req.State)
    111 	}
    112 
    113 	for k, v := range data {
    114 		q.Set(k, fmt.Sprint(v))
    115 	}
    116 
    117 	switch req.ResponseType {
    118 	case oauth2.Code:
    119 		u.RawQuery = q.Encode()
    120 	case oauth2.Token:
    121 		u.RawQuery = ""
    122 		fragment, err := url.QueryUnescape(q.Encode())
    123 		if err != nil {
    124 			return "", err
    125 		}
    126 		u.Fragment = fragment
    127 	}
    128 
    129 	return u.String(), nil
    130 }
    131 
    132 // CheckResponseType check allows response type
    133 func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
    134 	for _, art := range s.Config.AllowedResponseTypes {
    135 		if art == rt {
    136 			return true
    137 		}
    138 	}
    139 	return false
    140 }
    141 
    142 // CheckCodeChallengeMethod checks for allowed code challenge method
    143 func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool {
    144 	for _, c := range s.Config.AllowedCodeChallengeMethods {
    145 		if c == ccm {
    146 			return true
    147 		}
    148 	}
    149 	return false
    150 }
    151 
    152 // ValidationAuthorizeRequest the authorization request validation
    153 func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
    154 	redirectURI := r.FormValue("redirect_uri")
    155 	clientID := r.FormValue("client_id")
    156 	if !(r.Method == "GET" || r.Method == "POST") ||
    157 		clientID == "" {
    158 		return nil, errors.ErrInvalidRequest
    159 	}
    160 
    161 	resType := oauth2.ResponseType(r.FormValue("response_type"))
    162 	if resType.String() == "" {
    163 		return nil, errors.ErrUnsupportedResponseType
    164 	} else if allowed := s.CheckResponseType(resType); !allowed {
    165 		return nil, errors.ErrUnauthorizedClient
    166 	}
    167 
    168 	cc := r.FormValue("code_challenge")
    169 	if cc == "" && s.Config.ForcePKCE {
    170 		return nil, errors.ErrCodeChallengeRquired
    171 	}
    172 	if cc != "" && (len(cc) < 43 || len(cc) > 128) {
    173 		return nil, errors.ErrInvalidCodeChallengeLen
    174 	}
    175 
    176 	ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
    177 	// set default
    178 	if ccm == "" {
    179 		ccm = oauth2.CodeChallengePlain
    180 	}
    181 	if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) {
    182 		return nil, errors.ErrUnsupportedCodeChallengeMethod
    183 	}
    184 
    185 	req := &AuthorizeRequest{
    186 		RedirectURI:         redirectURI,
    187 		ResponseType:        resType,
    188 		ClientID:            clientID,
    189 		State:               r.FormValue("state"),
    190 		Scope:               r.FormValue("scope"),
    191 		Request:             r,
    192 		CodeChallenge:       cc,
    193 		CodeChallengeMethod: ccm,
    194 	}
    195 	return req, nil
    196 }
    197 
    198 // GetAuthorizeToken get authorization token(code)
    199 func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) {
    200 	// check the client allows the grant type
    201 	if fn := s.ClientAuthorizedHandler; fn != nil {
    202 		gt := oauth2.AuthorizationCode
    203 		if req.ResponseType == oauth2.Token {
    204 			gt = oauth2.Implicit
    205 		}
    206 
    207 		allowed, err := fn(req.ClientID, gt)
    208 		if err != nil {
    209 			return nil, err
    210 		} else if !allowed {
    211 			return nil, errors.ErrUnauthorizedClient
    212 		}
    213 	}
    214 
    215 	tgr := &oauth2.TokenGenerateRequest{
    216 		ClientID:       req.ClientID,
    217 		UserID:         req.UserID,
    218 		RedirectURI:    req.RedirectURI,
    219 		Scope:          req.Scope,
    220 		AccessTokenExp: req.AccessTokenExp,
    221 		Request:        req.Request,
    222 	}
    223 
    224 	// check the client allows the authorized scope
    225 	if fn := s.ClientScopeHandler; fn != nil {
    226 		allowed, err := fn(tgr)
    227 		if err != nil {
    228 			return nil, err
    229 		} else if !allowed {
    230 			return nil, errors.ErrInvalidScope
    231 		}
    232 	}
    233 
    234 	tgr.CodeChallenge = req.CodeChallenge
    235 	tgr.CodeChallengeMethod = req.CodeChallengeMethod
    236 
    237 	return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr)
    238 }
    239 
    240 // GetAuthorizeData get authorization response data
    241 func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
    242 	if rt == oauth2.Code {
    243 		return map[string]interface{}{
    244 			"code": ti.GetCode(),
    245 		}
    246 	}
    247 	return s.GetTokenData(ti)
    248 }
    249 
    250 // HandleAuthorizeRequest the authorization request handling
    251 func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
    252 	ctx := r.Context()
    253 
    254 	req, err := s.ValidationAuthorizeRequest(r)
    255 	if err != nil {
    256 		return s.redirectError(w, req, err)
    257 	}
    258 
    259 	// user authorization
    260 	userID, err := s.UserAuthorizationHandler(w, r)
    261 	if err != nil {
    262 		return s.redirectError(w, req, err)
    263 	} else if userID == "" {
    264 		return nil
    265 	}
    266 	req.UserID = userID
    267 
    268 	// specify the scope of authorization
    269 	if fn := s.AuthorizeScopeHandler; fn != nil {
    270 		scope, err := fn(w, r)
    271 		if err != nil {
    272 			return err
    273 		} else if scope != "" {
    274 			req.Scope = scope
    275 		}
    276 	}
    277 
    278 	// specify the expiration time of access token
    279 	if fn := s.AccessTokenExpHandler; fn != nil {
    280 		exp, err := fn(w, r)
    281 		if err != nil {
    282 			return err
    283 		}
    284 		req.AccessTokenExp = exp
    285 	}
    286 
    287 	ti, err := s.GetAuthorizeToken(ctx, req)
    288 	if err != nil {
    289 		return s.redirectError(w, req, err)
    290 	}
    291 
    292 	// If the redirect URI is empty, the default domain provided by the client is used.
    293 	if req.RedirectURI == "" {
    294 		client, err := s.Manager.GetClient(ctx, req.ClientID)
    295 		if err != nil {
    296 			return err
    297 		}
    298 		req.RedirectURI = client.GetDomain()
    299 	}
    300 
    301 	return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
    302 }
    303 
    304 // ValidationTokenRequest the token request validation
    305 func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
    306 	if v := r.Method; !(v == "POST" ||
    307 		(s.Config.AllowGetAccessRequest && v == "GET")) {
    308 		return "", nil, errors.ErrInvalidRequest
    309 	}
    310 
    311 	gt := oauth2.GrantType(r.FormValue("grant_type"))
    312 	if gt.String() == "" {
    313 		return "", nil, errors.ErrUnsupportedGrantType
    314 	}
    315 
    316 	if !s.CheckGrantType(gt) {
    317 		return "", nil, errors.ErrUnsupportedGrantType
    318 	}
    319 
    320 	clientID, clientSecret, err := s.ClientInfoHandler(r)
    321 	if err != nil {
    322 		return "", nil, err
    323 	}
    324 
    325 	tgr := &oauth2.TokenGenerateRequest{
    326 		ClientID:     clientID,
    327 		ClientSecret: clientSecret,
    328 		Request:      r,
    329 	}
    330 
    331 	switch gt {
    332 	case oauth2.AuthorizationCode:
    333 		tgr.RedirectURI = r.FormValue("redirect_uri")
    334 		tgr.Code = r.FormValue("code")
    335 		if tgr.RedirectURI == "" ||
    336 			tgr.Code == "" {
    337 			return "", nil, errors.ErrInvalidRequest
    338 		}
    339 		tgr.CodeVerifier = r.FormValue("code_verifier")
    340 		if s.Config.ForcePKCE && tgr.CodeVerifier == "" {
    341 			return "", nil, errors.ErrInvalidRequest
    342 		}
    343 	case oauth2.PasswordCredentials:
    344 		tgr.Scope = r.FormValue("scope")
    345 		username, password := r.FormValue("username"), r.FormValue("password")
    346 		if username == "" || password == "" {
    347 			return "", nil, errors.ErrInvalidRequest
    348 		}
    349 
    350 		userID, err := s.PasswordAuthorizationHandler(username, password)
    351 		if err != nil {
    352 			return "", nil, err
    353 		} else if userID == "" {
    354 			return "", nil, errors.ErrInvalidGrant
    355 		}
    356 		tgr.UserID = userID
    357 	case oauth2.ClientCredentials:
    358 		tgr.Scope = r.FormValue("scope")
    359 		tgr.RedirectURI = r.FormValue("redirect_uri")
    360 	case oauth2.Refreshing:
    361 		tgr.Refresh = r.FormValue("refresh_token")
    362 		tgr.Scope = r.FormValue("scope")
    363 		if tgr.Refresh == "" {
    364 			return "", nil, errors.ErrInvalidRequest
    365 		}
    366 	}
    367 	return gt, tgr, nil
    368 }
    369 
    370 // CheckGrantType check allows grant type
    371 func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
    372 	for _, agt := range s.Config.AllowedGrantTypes {
    373 		if agt == gt {
    374 			return true
    375 		}
    376 	}
    377 	return false
    378 }
    379 
    380 // GetAccessToken access token
    381 func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo,
    382 	error) {
    383 	if allowed := s.CheckGrantType(gt); !allowed {
    384 		return nil, errors.ErrUnauthorizedClient
    385 	}
    386 
    387 	if fn := s.ClientAuthorizedHandler; fn != nil {
    388 		allowed, err := fn(tgr.ClientID, gt)
    389 		if err != nil {
    390 			return nil, err
    391 		} else if !allowed {
    392 			return nil, errors.ErrUnauthorizedClient
    393 		}
    394 	}
    395 
    396 	switch gt {
    397 	case oauth2.AuthorizationCode:
    398 		ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr)
    399 		if err != nil {
    400 			switch err {
    401 			case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge:
    402 				return nil, errors.ErrInvalidGrant
    403 			case errors.ErrInvalidClient:
    404 				return nil, errors.ErrInvalidClient
    405 			default:
    406 				return nil, err
    407 			}
    408 		}
    409 		return ti, nil
    410 	case oauth2.PasswordCredentials, oauth2.ClientCredentials:
    411 		if fn := s.ClientScopeHandler; fn != nil {
    412 			allowed, err := fn(tgr)
    413 			if err != nil {
    414 				return nil, err
    415 			} else if !allowed {
    416 				return nil, errors.ErrInvalidScope
    417 			}
    418 		}
    419 		return s.Manager.GenerateAccessToken(ctx, gt, tgr)
    420 	case oauth2.Refreshing:
    421 		// check scope
    422 		if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil {
    423 			rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
    424 			if err != nil {
    425 				if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
    426 					return nil, errors.ErrInvalidGrant
    427 				}
    428 				return nil, err
    429 			}
    430 
    431 			allowed, err := scopeFn(tgr, rti.GetScope())
    432 			if err != nil {
    433 				return nil, err
    434 			} else if !allowed {
    435 				return nil, errors.ErrInvalidScope
    436 			}
    437 		}
    438 
    439 		if validationFn := s.RefreshingValidationHandler; validationFn != nil {
    440 			rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
    441 			if err != nil {
    442 				if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
    443 					return nil, errors.ErrInvalidGrant
    444 				}
    445 				return nil, err
    446 			}
    447 			allowed, err := validationFn(rti)
    448 			if err != nil {
    449 				return nil, err
    450 			} else if !allowed {
    451 				return nil, errors.ErrInvalidScope
    452 			}
    453 		}
    454 
    455 		ti, err := s.Manager.RefreshAccessToken(ctx, tgr)
    456 		if err != nil {
    457 			if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
    458 				return nil, errors.ErrInvalidGrant
    459 			}
    460 			return nil, err
    461 		}
    462 		return ti, nil
    463 	}
    464 
    465 	return nil, errors.ErrUnsupportedGrantType
    466 }
    467 
    468 // GetTokenData token data
    469 func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
    470 	data := map[string]interface{}{
    471 		"access_token": ti.GetAccess(),
    472 		"token_type":   s.Config.TokenType,
    473 		"expires_in":   int64(ti.GetAccessExpiresIn() / time.Second),
    474 	}
    475 
    476 	if scope := ti.GetScope(); scope != "" {
    477 		data["scope"] = scope
    478 	}
    479 
    480 	if refresh := ti.GetRefresh(); refresh != "" {
    481 		data["refresh_token"] = refresh
    482 	}
    483 
    484 	if fn := s.ExtensionFieldsHandler; fn != nil {
    485 		ext := fn(ti)
    486 		for k, v := range ext {
    487 			if _, ok := data[k]; ok {
    488 				continue
    489 			}
    490 			data[k] = v
    491 		}
    492 	}
    493 	return data
    494 }
    495 
    496 // HandleTokenRequest token request handling
    497 func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
    498 	ctx := r.Context()
    499 
    500 	gt, tgr, err := s.ValidationTokenRequest(r)
    501 	if err != nil {
    502 		return s.tokenError(w, err)
    503 	}
    504 
    505 	ti, err := s.GetAccessToken(ctx, gt, tgr)
    506 	if err != nil {
    507 		return s.tokenError(w, err)
    508 	}
    509 
    510 	return s.token(w, s.GetTokenData(ti), nil)
    511 }
    512 
    513 // GetErrorData get error response data
    514 func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
    515 	var re errors.Response
    516 	if v, ok := errors.Descriptions[err]; ok {
    517 		re.Error = err
    518 		re.Description = v
    519 		re.StatusCode = errors.StatusCodes[err]
    520 	} else {
    521 		if fn := s.InternalErrorHandler; fn != nil {
    522 			if v := fn(err); v != nil {
    523 				re = *v
    524 			}
    525 		}
    526 
    527 		if re.Error == nil {
    528 			re.Error = errors.ErrServerError
    529 			re.Description = errors.Descriptions[errors.ErrServerError]
    530 			re.StatusCode = errors.StatusCodes[errors.ErrServerError]
    531 		}
    532 	}
    533 
    534 	if fn := s.ResponseErrorHandler; fn != nil {
    535 		fn(&re)
    536 	}
    537 
    538 	data := make(map[string]interface{})
    539 	if err := re.Error; err != nil {
    540 		data["error"] = err.Error()
    541 	}
    542 
    543 	if v := re.ErrorCode; v != 0 {
    544 		data["error_code"] = v
    545 	}
    546 
    547 	if v := re.Description; v != "" {
    548 		data["error_description"] = v
    549 	}
    550 
    551 	if v := re.URI; v != "" {
    552 		data["error_uri"] = v
    553 	}
    554 
    555 	statusCode := http.StatusInternalServerError
    556 	if v := re.StatusCode; v > 0 {
    557 		statusCode = v
    558 	}
    559 
    560 	return data, statusCode, re.Header
    561 }
    562 
    563 // BearerAuth parse bearer token
    564 func (s *Server) BearerAuth(r *http.Request) (string, bool) {
    565 	auth := r.Header.Get("Authorization")
    566 	prefix := "Bearer "
    567 	token := ""
    568 
    569 	if auth != "" && strings.HasPrefix(auth, prefix) {
    570 		token = auth[len(prefix):]
    571 	} else {
    572 		token = r.FormValue("access_token")
    573 	}
    574 
    575 	return token, token != ""
    576 }
    577 
    578 // ValidationBearerToken validation the bearer tokens
    579 // https://tools.ietf.org/html/rfc6750
    580 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
    581 	ctx := r.Context()
    582 
    583 	accessToken, ok := s.BearerAuth(r)
    584 	if !ok {
    585 		return nil, errors.ErrInvalidAccessToken
    586 	}
    587 
    588 	return s.Manager.LoadAccessToken(ctx, accessToken)
    589 }