gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

server.go (11182B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package oauth
     19 
     20 import (
     21 	"context"
     22 	"errors"
     23 	"fmt"
     24 	"net/http"
     25 	"strings"
     26 
     27 	"github.com/superseriousbusiness/gotosocial/internal/db"
     28 	"github.com/superseriousbusiness/gotosocial/internal/gtserror"
     29 	"github.com/superseriousbusiness/gotosocial/internal/log"
     30 	"github.com/superseriousbusiness/oauth2/v4"
     31 	oautherr "github.com/superseriousbusiness/oauth2/v4/errors"
     32 	"github.com/superseriousbusiness/oauth2/v4/manage"
     33 	"github.com/superseriousbusiness/oauth2/v4/server"
     34 )
     35 
     36 const (
     37 	// SessionAuthorizedToken is the key set in the gin context for the Token
     38 	// of a User who has successfully passed Bearer token authorization.
     39 	// The interface returned from grabbing this key should be parsed as oauth2.TokenInfo
     40 	SessionAuthorizedToken = "authorized_token"
     41 	// SessionAuthorizedUser is the key set in the gin context for the id of
     42 	// a User who has successfully passed Bearer token authorization.
     43 	// The interface returned from grabbing this key should be parsed as a *gtsmodel.User
     44 	SessionAuthorizedUser = "authorized_user"
     45 	// SessionAuthorizedAccount is the key set in the gin context for the Account
     46 	// of a User who has successfully passed Bearer token authorization.
     47 	// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
     48 	SessionAuthorizedAccount = "authorized_account"
     49 	// SessionAuthorizedApplication is the key set in the gin context for the Application
     50 	// of a Client who has successfully passed Bearer token authorization.
     51 	// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
     52 	SessionAuthorizedApplication = "authorized_app"
     53 	// OOBURI is the out-of-band oauth token uri
     54 	OOBURI = "urn:ietf:wg:oauth:2.0:oob"
     55 	// OOBTokenPath is the path to redirect out-of-band token requests to.
     56 	OOBTokenPath = "/oauth/oob" // #nosec G101 else we get a hardcoded credentials warning
     57 	// HelpfulAdvice is a handy hint to users;
     58 	// particularly important during the login flow
     59 	HelpfulAdvice      = "If you arrived at this error during a login/oauth flow, please try clearing your session cookies and logging in again; if problems persist, make sure you're using the correct credentials"
     60 	HelpfulAdviceGrant = "If you arrived at this error during a login/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client"
     61 )
     62 
     63 // Server wraps some oauth2 server functions in an interface, exposing only what is needed
     64 type Server interface {
     65 	HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
     66 	HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
     67 	ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
     68 	GenerateUserAccessToken(ctx context.Context, ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
     69 	LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error)
     70 }
     71 
     72 // s fulfils the Server interface using the underlying oauth2 server
     73 type s struct {
     74 	server *server.Server
     75 }
     76 
     77 // New returns a new oauth server that implements the Server interface
     78 func New(ctx context.Context, database db.Basic) Server {
     79 	ts := newTokenStore(ctx, database)
     80 	cs := NewClientStore(database)
     81 
     82 	manager := manage.NewDefaultManager()
     83 	manager.MapTokenStorage(ts)
     84 	manager.MapClientStorage(cs)
     85 	manager.SetAuthorizeCodeTokenCfg(&manage.Config{
     86 		AccessTokenExp:    0,     // access tokens don't expire -- they must be revoked
     87 		IsGenerateRefresh: false, // don't use refresh tokens
     88 	})
     89 	sc := &server.Config{
     90 		TokenType: "Bearer",
     91 		// Must follow the spec.
     92 		AllowGetAccessRequest: false,
     93 		// Support only the non-implicit flow.
     94 		AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
     95 		// Allow:
     96 		// - Authorization Code (for first & third parties)
     97 		// - Client Credentials (for applications)
     98 		AllowedGrantTypes: []oauth2.GrantType{
     99 			oauth2.AuthorizationCode,
    100 			oauth2.ClientCredentials,
    101 		},
    102 		AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
    103 	}
    104 
    105 	srv := server.NewServer(sc, manager)
    106 	srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
    107 		log.Errorf(nil, "internal oauth error: %s", err)
    108 		return nil
    109 	})
    110 
    111 	srv.SetResponseErrorHandler(func(re *oautherr.Response) {
    112 		log.Errorf(nil, "internal response error: %s", re.Error)
    113 	})
    114 
    115 	srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
    116 		userID := r.FormValue("userid")
    117 		if userID == "" {
    118 			return "", errors.New("userid was empty")
    119 		}
    120 		return userID, nil
    121 	})
    122 	srv.SetClientInfoHandler(server.ClientFormHandler)
    123 	return &s{
    124 		server: srv,
    125 	}
    126 }
    127 
    128 // HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
    129 func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) {
    130 	ctx := r.Context()
    131 
    132 	gt, tgr, err := s.server.ValidationTokenRequest(r)
    133 	if err != nil {
    134 		help := fmt.Sprintf("could not validate token request: %s", err)
    135 		adv := HelpfulAdvice
    136 		if errors.Is(err, oautherr.ErrUnsupportedGrantType) {
    137 			adv = HelpfulAdviceGrant
    138 		}
    139 		return nil, gtserror.NewErrorBadRequest(err, help, adv)
    140 	}
    141 
    142 	ti, err := s.server.GetAccessToken(ctx, gt, tgr)
    143 	if err != nil {
    144 		help := fmt.Sprintf("could not get access token: %s", err)
    145 		return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
    146 	}
    147 
    148 	data := s.server.GetTokenData(ti)
    149 
    150 	if expiresInI, ok := data["expires_in"]; ok {
    151 		switch expiresIn := expiresInI.(type) {
    152 		case int64:
    153 			// remove this key from the returned map
    154 			// if the value is 0 or less, so that clients
    155 			// don't interpret the token as already expired
    156 			if expiresIn <= 0 {
    157 				delete(data, "expires_in")
    158 			}
    159 		default:
    160 			err := errors.New("expires_in was set on token response, but was not an int64")
    161 			return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
    162 		}
    163 	}
    164 
    165 	// add this for mastodon api compatibility
    166 	data["created_at"] = ti.GetAccessCreateAt().Unix()
    167 
    168 	return data, nil
    169 }
    170 
    171 func (s *s) errorOrRedirect(err error, w http.ResponseWriter, req *server.AuthorizeRequest) gtserror.WithCode {
    172 	if req == nil {
    173 		return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
    174 	}
    175 
    176 	data, _, _ := s.server.GetErrorData(err)
    177 	uri, err := s.server.GetRedirectURI(req, data)
    178 	if err != nil {
    179 		return gtserror.NewErrorInternalError(err, HelpfulAdvice)
    180 	}
    181 
    182 	w.Header().Set("Location", uri)
    183 	w.WriteHeader(http.StatusFound)
    184 	return nil
    185 }
    186 
    187 // HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
    188 func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode {
    189 	ctx := r.Context()
    190 
    191 	req, err := s.server.ValidationAuthorizeRequest(r)
    192 	if err != nil {
    193 		return s.errorOrRedirect(err, w, req)
    194 	}
    195 
    196 	// user authorization
    197 	userID, err := s.server.UserAuthorizationHandler(w, r)
    198 	if err != nil {
    199 		return s.errorOrRedirect(err, w, req)
    200 	}
    201 	if userID == "" {
    202 		help := "userID was empty"
    203 		return gtserror.NewErrorUnauthorized(err, help, HelpfulAdvice)
    204 	}
    205 	req.UserID = userID
    206 
    207 	// specify the scope of authorization
    208 	if fn := s.server.AuthorizeScopeHandler; fn != nil {
    209 		scope, err := fn(w, r)
    210 		if err != nil {
    211 			return s.errorOrRedirect(err, w, req)
    212 		} else if scope != "" {
    213 			req.Scope = scope
    214 		}
    215 	}
    216 
    217 	// specify the expiration time of access token
    218 	if fn := s.server.AccessTokenExpHandler; fn != nil {
    219 		exp, err := fn(w, r)
    220 		if err != nil {
    221 			return s.errorOrRedirect(err, w, req)
    222 		}
    223 		req.AccessTokenExp = exp
    224 	}
    225 
    226 	ti, err := s.server.GetAuthorizeToken(ctx, req)
    227 	if err != nil {
    228 		return s.errorOrRedirect(err, w, req)
    229 	}
    230 
    231 	// If the redirect URI is empty, the default domain provided by the client is used.
    232 	if req.RedirectURI == "" {
    233 		client, err := s.server.Manager.GetClient(ctx, req.ClientID)
    234 		if err != nil {
    235 			return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
    236 		}
    237 		req.RedirectURI = client.GetDomain()
    238 	}
    239 
    240 	uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))
    241 	if err != nil {
    242 		return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
    243 	}
    244 
    245 	if strings.Contains(uri, OOBURI) {
    246 		w.Header().Set("Location", strings.ReplaceAll(uri, OOBURI, OOBTokenPath))
    247 	} else {
    248 		w.Header().Set("Location", uri)
    249 	}
    250 
    251 	w.WriteHeader(http.StatusFound)
    252 	return nil
    253 }
    254 
    255 // ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function
    256 func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
    257 	return s.server.ValidationBearerToken(r)
    258 }
    259 
    260 // GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level
    261 // bearer token *without* requiring that user to log in. This is useful when we
    262 // need to create a token for new users who haven't validated their email or logged in yet.
    263 //
    264 // The ti parameter refers to an existing Application token that was used to make the upstream
    265 // request. This token needs to be validated and exist in database in order to create a new token.
    266 func (s *s) GenerateUserAccessToken(ctx context.Context, ti oauth2.TokenInfo, clientSecret string, userID string) (oauth2.TokenInfo, error) {
    267 	authToken, err := s.server.Manager.GenerateAuthToken(ctx, oauth2.Code, &oauth2.TokenGenerateRequest{
    268 		ClientID:     ti.GetClientID(),
    269 		ClientSecret: clientSecret,
    270 		UserID:       userID,
    271 		RedirectURI:  ti.GetRedirectURI(),
    272 		Scope:        ti.GetScope(),
    273 	})
    274 	if err != nil {
    275 		return nil, fmt.Errorf("error generating auth token: %s", err)
    276 	}
    277 	if authToken == nil {
    278 		return nil, errors.New("generated auth token was empty")
    279 	}
    280 	log.Tracef(ctx, "obtained auth token: %+v", authToken)
    281 
    282 	accessToken, err := s.server.Manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, &oauth2.TokenGenerateRequest{
    283 		ClientID:     authToken.GetClientID(),
    284 		ClientSecret: clientSecret,
    285 		RedirectURI:  authToken.GetRedirectURI(),
    286 		Scope:        authToken.GetScope(),
    287 		Code:         authToken.GetCode(),
    288 	})
    289 	if err != nil {
    290 		return nil, fmt.Errorf("error generating user-level access token: %s", err)
    291 	}
    292 	if accessToken == nil {
    293 		return nil, errors.New("generated user-level access token was empty")
    294 	}
    295 	log.Tracef(ctx, "obtained user-level access token: %+v", accessToken)
    296 	return accessToken, nil
    297 }
    298 
    299 func (s *s) LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) {
    300 	return s.server.Manager.LoadAccessToken(ctx, access)
    301 }