commit aa9ce272dcfa1380b2f05bc3a90ef8ca1b0a7f62
parent 4194f8d88f3bedab6841e80287c007f4c0e6f245
Author: Tobi Smethurst <31960611+tsmethurst@users.noreply.github.com>
Date: Mon, 22 Mar 2021 22:26:54 +0100
Oauth/token (#7)
* add host and protocol options
* some fiddling
* tidying up and comments
* tick off /oauth/token
* tidying a bit
* tidying
* go mod tidy
* allow attaching middleware to server
* add middleware
* more user friendly
* add comments
* comments
* store account + app
* tidying
* lots of restructuring
* lint + tidy
Diffstat:
36 files changed, 1912 insertions(+), 1543 deletions(-)
diff --git a/PROGRESS.md b/PROGRESS.md
@@ -6,7 +6,7 @@
* [ ] /api/v1/apps/verify_credentials GET (Verify an application works)
* [x] /oauth/authorize GET (Show authorize page to user)
* [x] /oauth/authorize POST (Get an oauth access code for an app/user)
- * [ ] /oauth/token POST (Obtain a user-level access token)
+ * [x] /oauth/token POST (Obtain a user-level access token)
* [ ] /oauth/revoke POST (Revoke a user-level access token)
* [x] /auth/sign_in GET (Show form for user signin)
* [x] /auth/sign_in POST (Validate username and password and sign user in)
diff --git a/cmd/gotosocial/main.go b/cmd/gotosocial/main.go
@@ -58,6 +58,18 @@ func main() {
Value: "",
EnvVars: []string{envNames.ConfigPath},
},
+ &cli.StringFlag{
+ Name: flagNames.Host,
+ Usage: "Hostname to use for the server (eg., example.org, gotosocial.whatever.com)",
+ Value: "localhost",
+ EnvVars: []string{envNames.Host},
+ },
+ &cli.StringFlag{
+ Name: flagNames.Protocol,
+ Usage: "Protocol to use for the REST api of the server (only use http for debugging and tests!)",
+ Value: "https",
+ EnvVars: []string{envNames.Protocol},
+ },
// DATABASE FLAGS
&cli.StringFlag{
diff --git a/example/config.yaml b/example/config.yaml
@@ -28,6 +28,17 @@ logLevel: "info"
# Default: "gotosocial"
applicationName: "gotosocial"
+# String. Hostname/domain to use for the server. Defaults to localhost for local testing,
+# but you should *definitely* change this when running for real, or your server won't work at all.
+# Examples: ["example.org","some.server.com"]
+# Default: "localhost"
+host: "localhost"
+
+# String. Protocol to use for the server. Only change to http for local testing!
+# Options: ["http","https"]
+# Default: "https"
+protocol: "https"
+
# Config pertaining to the Gotosocial database connection
db:
# String. Database type.
diff --git a/go.mod b/go.mod
@@ -10,7 +10,7 @@ require (
github.com/go-pg/pg/v10 v10.8.0
github.com/golang/mock v1.4.4 // indirect
github.com/google/uuid v1.2.0
- github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3
+ github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88
github.com/onsi/ginkgo v1.15.0 // indirect
github.com/onsi/gomega v1.10.5 // indirect
github.com/sirupsen/logrus v1.8.0
diff --git a/go.sum b/go.sum
@@ -103,8 +103,8 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3 h1:CKRz5d7mRum+UMR88Ue33tCYcej14WjUsB59C02DDqY=
-github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8=
+github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 h1:YJ//HmHOYJ4srm/LA6VPNjNisneMbY6TTM1xttV/ZQU=
+github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
diff --git a/internal/api/route_statuses.go b/internal/api/route_statuses.go
@@ -1,19 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package api
diff --git a/internal/api/server.go b/internal/api/server.go
@@ -1,87 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package api
-
-import (
- "fmt"
- "os"
- "path/filepath"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-contrib/sessions/memstore"
- "github.com/gin-gonic/gin"
- "github.com/gotosocial/gotosocial/internal/config"
- "github.com/sirupsen/logrus"
-)
-
-type Server interface {
- AttachHandler(method string, path string, handler gin.HandlerFunc)
- // AttachMiddleware(handler gin.HandlerFunc)
- GetAPIGroup() *gin.RouterGroup
- Start()
- Stop()
-}
-
-type AddsRoutes interface {
- AddRoutes(s Server) error
-}
-
-type server struct {
- APIGroup *gin.RouterGroup
- logger *logrus.Logger
- engine *gin.Engine
-}
-
-func (s *server) GetAPIGroup() *gin.RouterGroup {
- return s.APIGroup
-}
-
-func (s *server) Start() {
- // todo: start gracefully
- if err := s.engine.Run(); err != nil {
- s.logger.Panicf("server error: %s", err)
- }
-}
-
-func (s *server) Stop() {
- // todo: shut down gracefully
-}
-
-func (s *server) AttachHandler(method string, path string, handler gin.HandlerFunc) {
- if method == "ANY" {
- s.engine.Any(path, handler)
- } else {
- s.engine.Handle(method, path, handler)
- }
-}
-
-func New(config *config.Config, logger *logrus.Logger) Server {
- engine := gin.New()
- store := memstore.NewStore([]byte("authentication-key"), []byte("encryption-keyencryption-key----"))
- engine.Use(sessions.Sessions("gotosocial-session", store))
- cwd, _ := os.Getwd()
- tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir))
- logger.Debugf("loading templates from %s", tmPath)
- engine.LoadHTMLGlob(tmPath)
- return &server{
- APIGroup: engine.Group("/api").Group("/v1"),
- logger: logger,
- engine: engine,
- }
-}
diff --git a/internal/config/config.go b/internal/config/config.go
@@ -29,6 +29,8 @@ import (
type Config struct {
LogLevel string `yaml:"logLevel"`
ApplicationName string `yaml:"applicationName"`
+ Host string `yaml:"host"`
+ Protocol string `yaml:"protocol"`
DBConfig *DBConfig `yaml:"db"`
TemplateConfig *TemplateConfig `yaml:"template"`
}
@@ -97,6 +99,14 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) {
c.ApplicationName = f.String(fn.ApplicationName)
}
+ if c.Host == "" || f.IsSet(fn.Host) {
+ c.Host = f.String(fn.Host)
+ }
+
+ if c.Protocol == "" || f.IsSet(fn.Protocol) {
+ c.Protocol = f.String(fn.Protocol)
+ }
+
// db flags
if c.DBConfig.Type == "" || f.IsSet(fn.DbType) {
c.DBConfig.Type = f.String(fn.DbType)
@@ -142,6 +152,8 @@ type Flags struct {
LogLevel string
ApplicationName string
ConfigPath string
+ Host string
+ Protocol string
DbType string
DbAddress string
DbPort string
@@ -158,6 +170,8 @@ func GetFlagNames() Flags {
LogLevel: "log-level",
ApplicationName: "application-name",
ConfigPath: "config-path",
+ Host: "host",
+ Protocol: "protocol",
DbType: "db-type",
DbAddress: "db-address",
DbPort: "db-port",
@@ -175,6 +189,8 @@ func GetEnvNames() Flags {
LogLevel: "GTS_LOG_LEVEL",
ApplicationName: "GTS_APPLICATION_NAME",
ConfigPath: "GTS_CONFIG_PATH",
+ Host: "GTS_HOST",
+ Protocol: "GTS_PROTOCOL",
DbType: "GTS_DB_TYPE",
DbAddress: "GTS_DB_ADDRESS",
DbPort: "GTS_DB_PORT",
diff --git a/internal/db/actions.go b/internal/db/actions.go
@@ -28,9 +28,10 @@ import (
// Initialize will initialize the database given in the config for use with GoToSocial
var Initialize action.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- db, err := New(ctx, c, log)
- if err != nil {
- return err
- }
- return db.CreateSchema(ctx)
+ // db, err := New(ctx, c, log)
+ // if err != nil {
+ // return err
+ // }
+ return nil
+ // return db.CreateSchema(ctx)
}
diff --git a/internal/db/db.go b/internal/db/db.go
@@ -30,30 +30,47 @@ import (
const dbTypePostgres string = "POSTGRES"
-// DB provides methods for interacting with an underlying database (for now, just postgres).
-// The function mapping lines up with the DB interface described in go-fed.
-// See here: https://github.com/go-fed/activity/blob/master/pub/database.go
+// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
type DB interface {
- /*
- GO-FED DATABASE FUNCTIONS
- */
- pub.Database
+ // Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
+ // See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
+ Federation() pub.Database
- /*
- ANY ADDITIONAL DESIRED FUNCTIONS
- */
+ // CreateTable creates a table for the given interface
+ CreateTable(i interface{}) error
- // CreateSchema should populate the database with the required tables
- CreateSchema(context.Context) error
+ // DropTable drops the table for the given interface
+ DropTable(i interface{}) error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible
- Stop(context.Context) error
+ Stop(ctx context.Context) error
// IsHealthy should return nil if the database connection is healthy, or an error if not
- IsHealthy(context.Context) error
+ IsHealthy(ctx context.Context) error
+
+ // GetByID gets one entry by its id.
+ GetByID(id string, i interface{}) error
+
+ // GetWhere gets one entry where key = value
+ GetWhere(key string, value interface{}, i interface{}) error
+
+ // GetAll gets all entries of interface type i
+ GetAll(i interface{}) error
+
+ // Put stores i
+ Put(i interface{}) error
+
+ // Update by id updates i with id id
+ UpdateByID(id string, i interface{}) error
+
+ // Delete by id removes i with id id
+ DeleteByID(id string, i interface{}) error
+
+ // Delete where deletes i where key = value
+ DeleteWhere(key string, value interface{}, i interface{}) error
}
-// New returns a new database service that satisfies the Service interface and, by extension,
+// New returns a new database service that satisfies the DB interface and, by extension,
// the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go
func New(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) {
switch strings.ToUpper(c.DBConfig.Type) {
diff --git a/internal/db/pg-fed.go b/internal/db/pg-fed.go
@@ -0,0 +1,137 @@
+package db
+
+import (
+ "context"
+ "errors"
+ "net/url"
+ "sync"
+
+ "github.com/go-fed/activity/pub"
+ "github.com/go-fed/activity/streams"
+ "github.com/go-fed/activity/streams/vocab"
+ "github.com/go-pg/pg/v10"
+)
+
+type postgresFederation struct {
+ locks *sync.Map
+ conn *pg.DB
+}
+
+func newPostgresFederation(conn *pg.DB) pub.Database {
+ return &postgresFederation{
+ locks: new(sync.Map),
+ conn: conn,
+ }
+}
+
+/*
+ GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
+*/
+func (pf *postgresFederation) Lock(ctx context.Context, id *url.URL) error {
+ // Before any other Database methods are called, the relevant `id`
+ // entries are locked to allow for fine-grained concurrency.
+
+ // Strategy: create a new lock, if stored, continue. Otherwise, lock the
+ // existing mutex.
+ mu := &sync.Mutex{}
+ mu.Lock() // Optimistically lock if we do store it.
+ i, loaded := pf.locks.LoadOrStore(id.String(), mu)
+ if loaded {
+ mu = i.(*sync.Mutex)
+ mu.Lock()
+ }
+ return nil
+}
+
+func (pf *postgresFederation) Unlock(ctx context.Context, id *url.URL) error {
+ // Once Go-Fed is done calling Database methods, the relevant `id`
+ // entries are unlocked.
+
+ i, ok := pf.locks.Load(id.String())
+ if !ok {
+ return errors.New("missing an id in unlock")
+ }
+ mu := i.(*sync.Mutex)
+ mu.Unlock()
+ return nil
+}
+
+func (pf *postgresFederation) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
+ return false, nil
+}
+
+func (pf *postgresFederation) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
+ return nil
+}
+
+func (pf *postgresFederation) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
+ return false, nil
+}
+
+func (pf *postgresFederation) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
+ return false, nil
+}
+
+func (pf *postgresFederation) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) Create(ctx context.Context, asType vocab.Type) error {
+ t, err := streams.NewTypeResolver()
+ if err != nil {
+ return err
+ }
+ if err := t.Resolve(ctx, asType); err != nil {
+ return err
+ }
+ asType.GetTypeName()
+ return nil
+}
+
+func (pf *postgresFederation) Update(ctx context.Context, asType vocab.Type) error {
+ return nil
+}
+
+func (pf *postgresFederation) Delete(ctx context.Context, id *url.URL) error {
+ return nil
+}
+
+func (pf *postgresFederation) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
+ return nil
+}
+
+func (pf *postgresFederation) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
+ return nil, nil
+}
+
+func (pf *postgresFederation) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
+ return nil, nil
+}
diff --git a/internal/db/pg.go b/internal/db/pg.go
@@ -0,0 +1,251 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "regexp"
+ "strings"
+ "time"
+
+ "github.com/go-fed/activity/pub"
+ "github.com/go-pg/pg/extra/pgdebug"
+ "github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/gotosocial/gotosocial/internal/config"
+ "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/sirupsen/logrus"
+)
+
+// postgresService satisfies the DB interface
+type postgresService struct {
+ config *config.DBConfig
+ conn *pg.DB
+ log *logrus.Entry
+ cancel context.CancelFunc
+ federationDB pub.Database
+}
+
+// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
+// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
+func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
+ opts, err := derivePGOptions(c)
+ if err != nil {
+ return nil, fmt.Errorf("could not create postgres service: %s", err)
+ }
+ log.Debugf("using pg options: %+v", opts)
+
+ readyChan := make(chan interface{})
+ opts.OnConnect = func(ctx context.Context, c *pg.Conn) error {
+ close(readyChan)
+ return nil
+ }
+
+ // create a connection
+ pgCtx, cancel := context.WithCancel(ctx)
+ conn := pg.Connect(opts).WithContext(pgCtx)
+
+ // this will break the logfmt format we normally log in,
+ // since we can't choose where pg outputs to and it defaults to
+ // stdout. So use this option with care!
+ if log.Logger.GetLevel() >= logrus.TraceLevel {
+ conn.AddQueryHook(pgdebug.DebugHook{
+ // Print all queries.
+ Verbose: true,
+ })
+ }
+
+ // actually *begin* the connection so that we can tell if the db is there
+ // and listening, and also trigger the opts.OnConnect function passed in above
+ if err := conn.Ping(ctx); err != nil {
+ cancel()
+ return nil, fmt.Errorf("db connection error: %s", err)
+ }
+
+ // print out discovered postgres version
+ var version string
+ if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
+ cancel()
+ return nil, fmt.Errorf("db connection error: %s", err)
+ }
+ log.Infof("connected to postgres version: %s", version)
+
+ // make sure the opts.OnConnect function has been triggered
+ // and closed the ready channel
+ select {
+ case <-readyChan:
+ log.Infof("postgres connection ready")
+ case <-time.After(5 * time.Second):
+ cancel()
+ return nil, errors.New("db connection timeout")
+ }
+
+ // we can confidently return this useable postgres service now
+ return &postgresService{
+ config: c.DBConfig,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ federationDB: newPostgresFederation(conn),
+ }, nil
+}
+
+func (ps *postgresService) Federation() pub.Database {
+ return ps.federationDB
+}
+
+/*
+ HANDY STUFF
+*/
+
+// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
+// with sensible defaults, or an error if it's not satisfied by the provided config.
+func derivePGOptions(c *config.Config) (*pg.Options, error) {
+ if strings.ToUpper(c.DBConfig.Type) != dbTypePostgres {
+ return nil, fmt.Errorf("expected db type of %s but got %s", dbTypePostgres, c.DBConfig.Type)
+ }
+
+ // validate port
+ if c.DBConfig.Port == 0 {
+ return nil, errors.New("no port set")
+ }
+
+ // validate address
+ if c.DBConfig.Address == "" {
+ return nil, errors.New("no address set")
+ }
+
+ ipv4Regex := regexp.MustCompile(`^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`)
+ hostnameRegex := regexp.MustCompile(`^(?:[a-z0-9]+(?:-[a-z0-9]+)*\.)+[a-z]{2,}$`)
+ if !hostnameRegex.MatchString(c.DBConfig.Address) && !ipv4Regex.MatchString(c.DBConfig.Address) && c.DBConfig.Address != "localhost" {
+ return nil, fmt.Errorf("address %s was neither an ipv4 address nor a valid hostname", c.DBConfig.Address)
+ }
+
+ // validate username
+ if c.DBConfig.User == "" {
+ return nil, errors.New("no user set")
+ }
+
+ // validate that there's a password
+ if c.DBConfig.Password == "" {
+ return nil, errors.New("no password set")
+ }
+
+ // validate database
+ if c.DBConfig.Database == "" {
+ return nil, errors.New("no database set")
+ }
+
+ // We can rely on the pg library we're using to set
+ // sensible defaults for everything we don't set here.
+ options := &pg.Options{
+ Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
+ User: c.DBConfig.User,
+ Password: c.DBConfig.Password,
+ Database: c.DBConfig.Database,
+ ApplicationName: c.ApplicationName,
+ }
+
+ return options, nil
+}
+
+/*
+ EXTRA FUNCTIONS
+*/
+
+func (ps *postgresService) Stop(ctx context.Context) error {
+ ps.log.Info("closing db connection")
+ if err := ps.conn.Close(); err != nil {
+ // only cancel if there's a problem closing the db
+ ps.cancel()
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) CreateSchema(ctx context.Context) error {
+ models := []interface{}{
+ (*gtsmodel.Account)(nil),
+ (*gtsmodel.Status)(nil),
+ (*gtsmodel.User)(nil),
+ }
+ ps.log.Info("creating db schema")
+
+ for _, model := range models {
+ err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
+ IfNotExists: true,
+ })
+ if err != nil {
+ return err
+ }
+ }
+
+ ps.log.Info("db schema created")
+ return nil
+}
+
+func (ps *postgresService) IsHealthy(ctx context.Context) error {
+ return ps.conn.Ping(ctx)
+}
+
+func (ps *postgresService) CreateTable(i interface{}) error {
+ return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
+ IfNotExists: true,
+ })
+}
+
+func (ps *postgresService) DropTable(i interface{}) error {
+ return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
+ IfExists: true,
+ })
+}
+
+func (ps *postgresService) GetByID(id string, i interface{}) error {
+ return ps.conn.Model(i).Where("id = ?", id).Select()
+}
+
+func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
+ return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
+}
+
+func (ps *postgresService) GetAll(i interface{}) error {
+ return ps.conn.Model(i).Select()
+}
+
+func (ps *postgresService) Put(i interface{}) error {
+ _, err := ps.conn.Model(i).Insert(i)
+ return err
+}
+
+func (ps *postgresService) UpdateByID(id string, i interface{}) error {
+ _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
+ return err
+}
+
+func (ps *postgresService) DeleteByID(id string, i interface{}) error {
+ _, err := ps.conn.Model(i).Where("id = ?", id).Delete()
+ return err
+}
+
+func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
+ _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
+ return err
+}
diff --git a/internal/db/postgres.go b/internal/db/postgres.go
@@ -1,343 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package db
-
-import (
- "context"
- "errors"
- "fmt"
- "net/url"
- "regexp"
- "strings"
- "sync"
- "time"
-
- "github.com/go-fed/activity/streams"
- "github.com/go-fed/activity/streams/vocab"
- "github.com/go-pg/pg/extra/pgdebug"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/gotosocial/gotosocial/internal/config"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
- "github.com/gotosocial/oauth2/v4"
- "github.com/sirupsen/logrus"
-)
-
-type postgresService struct {
- config *config.DBConfig
- conn *pg.DB
- log *logrus.Entry
- cancel context.CancelFunc
- locks *sync.Map
- tokenStore oauth2.TokenStore
-}
-
-// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
-// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
-func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
- opts, err := derivePGOptions(c)
- if err != nil {
- return nil, fmt.Errorf("could not create postgres service: %s", err)
- }
- log.Debugf("using pg options: %+v", opts)
-
- readyChan := make(chan interface{})
- opts.OnConnect = func(ctx context.Context, c *pg.Conn) error {
- close(readyChan)
- return nil
- }
-
- // create a connection
- pgCtx, cancel := context.WithCancel(ctx)
- conn := pg.Connect(opts).WithContext(pgCtx)
-
- // this will break the logfmt format we normally log in,
- // since we can't choose where pg outputs to and it defaults to
- // stdout. So use this option with care!
- if log.Logger.GetLevel() >= logrus.TraceLevel {
- conn.AddQueryHook(pgdebug.DebugHook{
- // Print all queries.
- Verbose: true,
- })
- }
-
- // actually *begin* the connection so that we can tell if the db is there
- // and listening, and also trigger the opts.OnConnect function passed in above
- if err := conn.Ping(ctx); err != nil {
- cancel()
- return nil, fmt.Errorf("db connection error: %s", err)
- }
-
- // print out discovered postgres version
- var version string
- if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
- cancel()
- return nil, fmt.Errorf("db connection error: %s", err)
- }
- log.Infof("connected to postgres version: %s", version)
-
- // make sure the opts.OnConnect function has been triggered
- // and closed the ready channel
- select {
- case <-readyChan:
- log.Infof("postgres connection ready")
- case <-time.After(5 * time.Second):
- cancel()
- return nil, errors.New("db connection timeout")
- }
-
- // acc := model.StubAccount()
- // if _, err := conn.Model(acc).Returning("id").Insert(); err != nil {
- // cancel()
- // return nil, fmt.Errorf("db insert error: %s", err)
- // }
- // log.Infof("created account with id %s", acc.ID)
-
- // note := &model.Note{
- // Visibility: &model.Visibility{
- // Local: true,
- // },
- // CreatedAt: time.Now(),
- // UpdatedAt: time.Now(),
- // }
- // if _, err := conn.WithContext(ctx).Model(note).Returning("id").Insert(); err != nil {
- // cancel()
- // return nil, fmt.Errorf("db insert error: %s", err)
- // }
- // log.Infof("created note with id %s", note.ID)
-
- // we can confidently return this useable postgres service now
- return &postgresService{
- config: c.DBConfig,
- conn: conn,
- log: log,
- cancel: cancel,
- locks: &sync.Map{},
- }, nil
-}
-
-/*
- HANDY STUFF
-*/
-
-// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
-// with sensible defaults, or an error if it's not satisfied by the provided config.
-func derivePGOptions(c *config.Config) (*pg.Options, error) {
- if strings.ToUpper(c.DBConfig.Type) != dbTypePostgres {
- return nil, fmt.Errorf("expected db type of %s but got %s", dbTypePostgres, c.DBConfig.Type)
- }
-
- // validate port
- if c.DBConfig.Port == 0 {
- return nil, errors.New("no port set")
- }
-
- // validate address
- if c.DBConfig.Address == "" {
- return nil, errors.New("no address set")
- }
-
- ipv4Regex := regexp.MustCompile(`^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`)
- hostnameRegex := regexp.MustCompile(`^(?:[a-z0-9]+(?:-[a-z0-9]+)*\.)+[a-z]{2,}$`)
- if !hostnameRegex.MatchString(c.DBConfig.Address) && !ipv4Regex.MatchString(c.DBConfig.Address) && c.DBConfig.Address != "localhost" {
- return nil, fmt.Errorf("address %s was neither an ipv4 address nor a valid hostname", c.DBConfig.Address)
- }
-
- // validate username
- if c.DBConfig.User == "" {
- return nil, errors.New("no user set")
- }
-
- // validate that there's a password
- if c.DBConfig.Password == "" {
- return nil, errors.New("no password set")
- }
-
- // validate database
- if c.DBConfig.Database == "" {
- return nil, errors.New("no database set")
- }
-
- // We can rely on the pg library we're using to set
- // sensible defaults for everything we don't set here.
- options := &pg.Options{
- Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
- User: c.DBConfig.User,
- Password: c.DBConfig.Password,
- Database: c.DBConfig.Database,
- ApplicationName: c.ApplicationName,
- }
-
- return options, nil
-}
-
-/*
- GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
-*/
-func (ps *postgresService) Lock(ctx context.Context, id *url.URL) error {
- // Before any other Database methods are called, the relevant `id`
- // entries are locked to allow for fine-grained concurrency.
-
- // Strategy: create a new lock, if stored, continue. Otherwise, lock the
- // existing mutex.
- mu := &sync.Mutex{}
- mu.Lock() // Optimistically lock if we do store it.
- i, loaded := ps.locks.LoadOrStore(id.String(), mu)
- if loaded {
- mu = i.(*sync.Mutex)
- mu.Lock()
- }
- return nil
-}
-
-func (ps *postgresService) Unlock(ctx context.Context, id *url.URL) error {
- // Once Go-Fed is done calling Database methods, the relevant `id`
- // entries are unlocked.
-
- i, ok := ps.locks.Load(id.String())
- if !ok {
- return errors.New("missing an id in unlock")
- }
- mu := i.(*sync.Mutex)
- mu.Unlock()
- return nil
-}
-
-func (ps *postgresService) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
- return false, nil
-}
-
-func (ps *postgresService) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
- return nil
-}
-
-func (ps *postgresService) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
- return false, nil
-}
-
-func (ps *postgresService) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
- return false, nil
-}
-
-func (ps *postgresService) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) Create(ctx context.Context, asType vocab.Type) error {
- t, err := streams.NewTypeResolver()
- if err != nil {
- return err
- }
- if err := t.Resolve(ctx, asType); err != nil {
- return err
- }
- asType.GetTypeName()
- return nil
-}
-
-func (ps *postgresService) Update(ctx context.Context, asType vocab.Type) error {
- return nil
-}
-
-func (ps *postgresService) Delete(ctx context.Context, id *url.URL) error {
- return nil
-}
-
-func (ps *postgresService) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
- return nil
-}
-
-func (ps *postgresService) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
- return nil, nil
-}
-
-func (ps *postgresService) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
- return nil, nil
-}
-
-/*
- EXTRA FUNCTIONS
-*/
-
-func (ps *postgresService) Stop(ctx context.Context) error {
- ps.log.Info("closing db connection")
- if err := ps.conn.Close(); err != nil {
- // only cancel if there's a problem closing the db
- ps.cancel()
- return err
- }
- return nil
-}
-
-func (ps *postgresService) CreateSchema(ctx context.Context) error {
- models := []interface{}{
- (*gtsmodel.Account)(nil),
- (*gtsmodel.Status)(nil),
- (*gtsmodel.User)(nil),
- }
- ps.log.Info("creating db schema")
-
- for _, model := range models {
- err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
- if err != nil {
- return err
- }
- }
-
- ps.log.Info("db schema created")
- return nil
-}
-
-func (ps *postgresService) IsHealthy(ctx context.Context) error {
- return ps.conn.Ping(ctx)
-}
-
-func (ps *postgresService) TokenStore() oauth2.TokenStore {
- return ps.tokenStore
-}
diff --git a/internal/email/email.go b/internal/email/email.go
@@ -16,5 +16,5 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-// package email provides a service for interacting with an SMTP server
+// Package email provides a service for interacting with an SMTP server
package email
diff --git a/internal/federation/federation.go b/internal/federation/federation.go
@@ -30,11 +30,13 @@ import (
"github.com/gotosocial/gotosocial/internal/db"
)
+// New returns a go-fed compatible federating actor
func New(db db.DB) pub.FederatingActor {
fa := &API{}
- return pub.NewFederatingActor(fa, fa, db, fa)
+ return pub.NewFederatingActor(fa, fa, db.Federation(), fa)
}
+// API implements several go-fed interfaces in one convenient location
type API struct {
}
diff --git a/internal/gotosocial/actions.go b/internal/gotosocial/actions.go
@@ -38,9 +38,9 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr
return fmt.Errorf("error creating dbservice: %s", err)
}
- if err := dbService.CreateSchema(ctx); err != nil {
- return fmt.Errorf("error creating dbschema: %s", err)
- }
+ // if err := dbService.CreateSchema(ctx); err != nil {
+ // return fmt.Errorf("error creating dbschema: %s", err)
+ // }
// catch shutdown signals from the operating system
sigs := make(chan os.Signal, 1)
diff --git a/internal/gotosocial/gotosocial.go b/internal/gotosocial/gotosocial.go
@@ -22,10 +22,10 @@ import (
"context"
"github.com/go-fed/activity/pub"
- "github.com/gotosocial/gotosocial/internal/api"
"github.com/gotosocial/gotosocial/internal/cache"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
+ "github.com/gotosocial/gotosocial/internal/router"
)
type Gotosocial interface {
@@ -33,11 +33,11 @@ type Gotosocial interface {
Stop(context.Context) error
}
-func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) {
+func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) {
return &gotosocial{
db: db,
cache: cache,
- clientAPI: clientAPI,
+ apiRouter: apiRouter,
federationAPI: federationAPI,
config: config,
}, nil
@@ -46,7 +46,7 @@ func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.Fe
type gotosocial struct {
db db.DB
cache cache.Cache
- clientAPI api.Server
+ apiRouter router.Router
federationAPI pub.FederatingActor
config *config.Config
}
diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go
@@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-// package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
+// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/
package gtsmodel
diff --git a/internal/gtsmodel/application.go b/internal/gtsmodel/application.go
@@ -18,13 +18,38 @@
package gtsmodel
+import "github.com/gotosocial/gotosocial/pkg/mastotypes"
+
+// Application represents an application that can perform actions on behalf of a user.
+// It is used to authorize tokens etc, and is associated with an oauth client id in the database.
type Application struct {
- ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
- Name string
- Website string
- RedirectURI string `json:"redirect_uri"`
- ClientID string `json:"client_id"`
- ClientSecret string `json:"client_secret"`
- Scopes string `json:"scopes"`
- VapidKey string `json:"vapid_key"`
+ // id of this application in the db
+ ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
+ // name of the application given when it was created (eg., 'tusky')
+ Name string
+ // website for the application given when it was created (eg., 'https://tusky.app')
+ Website string
+ // redirect uri requested by the application for oauth2 flow
+ RedirectURI string
+ // id of the associated oauth client entity in the db
+ ClientID string
+ // secret of the associated oauth client entity in the db
+ ClientSecret string
+ // scopes requested when this app was created
+ Scopes string
+ // a vapid key generated for this app when it was created
+ VapidKey string
+}
+
+// ToMastotype returns this application as a mastodon api type, ready for serialization
+func (a *Application) ToMastotype() *mastotypes.Application {
+ return &mastotypes.Application{
+ ID: a.ID,
+ Name: a.Name,
+ Website: a.Website,
+ RedirectURI: a.RedirectURI,
+ ClientID: a.ClientID,
+ ClientSecret: a.ClientSecret,
+ VapidKey: a.VapidKey,
+ }
}
diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go
@@ -20,25 +20,44 @@ package gtsmodel
import "time"
+// Status represents a user-created 'post' or 'status' in the database, either remote or local
type Status struct {
- ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
- URI string `pg:",unique"`
- URL string `pg:",unique"`
- Content string
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
- Local bool
- AccountID string
- InReplyToID string
- BoostOfID string
+ // id of the status in the database
+ ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
+ // uri at which this status is reachable
+ URI string `pg:",unique"`
+ // web url for viewing this status
+ URL string `pg:",unique"`
+ // the html-formatted content of this status
+ Content string
+ // when was this status created?
+ CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ // when was this status updated?
+ UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ // is this status from a local account?
+ Local bool
+ // which account posted this status?
+ AccountID string
+ // id of the status this status is a reply to
+ InReplyToID string
+ // id of the status this status is a boost of
+ BoostOfID string
+ // cw string for this status
ContentWarning string
- Visibility *Visibility
+ // visibility entry for this status
+ Visibility *Visibility
}
+// Visibility represents the visibility granularity of a status. It is a combination of flags.
type Visibility struct {
- Direct bool
+ // Is this status viewable as a direct message?
+ Direct bool
+ // Is this status viewable to followers?
Followers bool
- Local bool
- Unlisted bool
- Public bool
+ // Is this status viewable on the local timeline?
+ Local bool
+ // Is this status boostable but not shown on public timelines?
+ Unlisted bool
+ // Is this status shown on public and federated timelines?
+ Public bool
}
diff --git a/internal/module/account/account.go b/internal/module/account/account.go
@@ -0,0 +1,37 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package account
+
+import (
+ "github.com/gotosocial/gotosocial/internal/module"
+ "github.com/gotosocial/gotosocial/internal/router"
+)
+
+type accountModule struct {
+}
+
+// New returns a new account module
+func New() module.ClientAPIModule {
+ return &accountModule{}
+}
+
+// Route attaches all routes from this module to the given router
+func (m *accountModule) Route(r router.Router) error {
+ return nil
+}
diff --git a/internal/module/module.go b/internal/module/module.go
@@ -0,0 +1,29 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+// Package module is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface.
+package module
+
+import "github.com/gotosocial/gotosocial/internal/router"
+
+// ClientAPIModule represents a chunk of code (usually contained in a single package) that adds a set
+// of functionalities and side effects to a router, by mapping routes and handlers onto it--in other words, a REST API ;)
+// A ClientAPIMpdule corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/
+type ClientAPIModule interface {
+ Route(s router.Router) error
+}
diff --git a/internal/module/oauth/README.md b/internal/module/oauth/README.md
@@ -0,0 +1,5 @@
+# oauth
+
+This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.
+
+It also provides a handler/middleware for attaching to the Gin engine for validating authenticated users.
diff --git a/internal/module/oauth/clientstore.go b/internal/module/oauth/clientstore.go
@@ -0,0 +1,73 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package oauth
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/gotosocial/gotosocial/internal/db"
+ "github.com/gotosocial/oauth2/v4"
+ "github.com/gotosocial/oauth2/v4/models"
+)
+
+type clientStore struct {
+ db db.DB
+}
+
+func newClientStore(db db.DB) oauth2.ClientStore {
+ pts := &clientStore{
+ db: db,
+ }
+ return pts
+}
+
+func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
+ poc := &oauthClient{
+ ID: clientID,
+ }
+ if err := cs.db.GetByID(clientID, poc); err != nil {
+ return nil, fmt.Errorf("database error: %s", err)
+ }
+ return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
+}
+
+func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
+ poc := &oauthClient{
+ ID: cli.GetID(),
+ Secret: cli.GetSecret(),
+ Domain: cli.GetDomain(),
+ UserID: cli.GetUserID(),
+ }
+ return cs.db.UpdateByID(id, poc)
+}
+
+func (cs *clientStore) Delete(ctx context.Context, id string) error {
+ poc := &oauthClient{
+ ID: id,
+ }
+ return cs.db.DeleteByID(id, poc)
+}
+
+type oauthClient struct {
+ ID string
+ Secret string
+ Domain string
+ UserID string
+}
diff --git a/internal/module/oauth/clientstore_test.go b/internal/module/oauth/clientstore_test.go
@@ -0,0 +1,144 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+package oauth
+
+import (
+ "context"
+ "testing"
+
+ "github.com/gotosocial/gotosocial/internal/config"
+ "github.com/gotosocial/gotosocial/internal/db"
+ "github.com/gotosocial/oauth2/v4/models"
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/suite"
+)
+
+type PgClientStoreTestSuite struct {
+ suite.Suite
+ db db.DB
+ testClientID string
+ testClientSecret string
+ testClientDomain string
+ testClientUserID string
+}
+
+const ()
+
+// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
+func (suite *PgClientStoreTestSuite) SetupSuite() {
+ suite.testClientID = "test-client-id"
+ suite.testClientSecret = "test-client-secret"
+ suite.testClientDomain = "https://example.org"
+ suite.testClientUserID = "test-client-user-id"
+}
+
+// SetupTest creates a postgres connection and creates the oauth_clients table before each test
+func (suite *PgClientStoreTestSuite) SetupTest() {
+ log := logrus.New()
+ log.SetLevel(logrus.TraceLevel)
+ c := config.Empty()
+ c.DBConfig = &config.DBConfig{
+ Type: "postgres",
+ Address: "localhost",
+ Port: 5432,
+ User: "postgres",
+ Password: "postgres",
+ Database: "postgres",
+ ApplicationName: "gotosocial",
+ }
+ db, err := db.New(context.Background(), c, log)
+ if err != nil {
+ logrus.Panicf("error creating database connection: %s", err)
+ }
+
+ suite.db = db
+
+ models := []interface{}{
+ &oauthClient{},
+ }
+
+ for _, m := range models {
+ if err := suite.db.CreateTable(m); err != nil {
+ logrus.Panicf("db connection error: %s", err)
+ }
+ }
+}
+
+// TearDownTest drops the oauth_clients table and closes the pg connection after each test
+func (suite *PgClientStoreTestSuite) TearDownTest() {
+ models := []interface{}{
+ &oauthClient{},
+ }
+ for _, m := range models {
+ if err := suite.db.DropTable(m); err != nil {
+ logrus.Panicf("error dropping table: %s", err)
+ }
+ }
+ if err := suite.db.Stop(context.Background()); err != nil {
+ logrus.Panicf("error closing db connection: %s", err)
+ }
+ suite.db = nil
+}
+
+func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
+ // set a new client in the store
+ cs := newClientStore(suite.db)
+ if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // fetch that client from the store
+ client, err := cs.GetByID(context.Background(), suite.testClientID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // check that the values are the same
+ suite.NotNil(client)
+ suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
+}
+
+func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
+ // set a new client in the store
+ cs := newClientStore(suite.db)
+ if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // fetch the client from the store
+ client, err := cs.GetByID(context.Background(), suite.testClientID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // check that the values are the same
+ suite.NotNil(client)
+ suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
+ if err := cs.Delete(context.Background(), suite.testClientID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // try to get the deleted client; we should get an error
+ deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
+ suite.Assert().Nil(deletedClient)
+ suite.Assert().NotNil(err)
+}
+
+func TestPgClientStoreTestSuite(t *testing.T) {
+ suite.Run(t, new(PgClientStoreTestSuite))
+}
diff --git a/internal/module/oauth/oauth.go b/internal/module/oauth/oauth.go
@@ -0,0 +1,510 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+// Package oauth is a module that provides oauth functionality to a router.
+// It adds the following paths:
+// /api/v1/apps
+// /auth/sign_in
+// /oauth/token
+// /oauth/authorize
+// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
+package oauth
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ "github.com/gotosocial/gotosocial/internal/db"
+ "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/gotosocial/gotosocial/internal/module"
+ "github.com/gotosocial/gotosocial/internal/router"
+ "github.com/gotosocial/gotosocial/pkg/mastotypes"
+ "github.com/gotosocial/oauth2/v4"
+ "github.com/gotosocial/oauth2/v4/errors"
+ "github.com/gotosocial/oauth2/v4/manage"
+ "github.com/gotosocial/oauth2/v4/server"
+ "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/bcrypt"
+)
+
+const (
+ appsPath = "/api/v1/apps"
+ authSignInPath = "/auth/sign_in"
+ oauthTokenPath = "/oauth/token"
+ oauthAuthorizePath = "/oauth/authorize"
+)
+
+// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface
+type oauthModule struct {
+ oauthManager *manage.Manager
+ oauthServer *server.Server
+ db db.DB
+ log *logrus.Logger
+}
+
+type login struct {
+ Email string `form:"username"`
+ Password string `form:"password"`
+}
+
+// New returns a new oauth module
+func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule {
+ manager := manage.NewDefaultManager()
+ manager.MapTokenStorage(ts)
+ manager.MapClientStorage(cs)
+ manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
+ sc := &server.Config{
+ TokenType: "Bearer",
+ // Must follow the spec.
+ AllowGetAccessRequest: false,
+ // Support only the non-implicit flow.
+ AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
+ // Allow:
+ // - Authorization Code (for first & third parties)
+ AllowedGrantTypes: []oauth2.GrantType{
+ oauth2.AuthorizationCode,
+ },
+ AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
+ }
+
+ srv := server.NewServer(sc, manager)
+ srv.SetInternalErrorHandler(func(err error) *errors.Response {
+ log.Errorf("internal oauth error: %s", err)
+ return nil
+ })
+
+ srv.SetResponseErrorHandler(func(re *errors.Response) {
+ log.Errorf("internal response error: %s", re.Error)
+ })
+
+ m := &oauthModule{
+ oauthManager: manager,
+ oauthServer: srv,
+ db: db,
+ log: log,
+ }
+
+ m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
+ m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
+ return m
+}
+
+// Route satisfies the RESTAPIModule interface
+func (m *oauthModule) Route(s router.Router) error {
+ s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
+
+ s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
+ s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
+
+ s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler)
+
+ s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler)
+ s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
+
+ s.AttachMiddleware(m.oauthTokenMiddleware)
+
+ return nil
+}
+
+/*
+ MAIN HANDLERS -- serve these through a server/router
+*/
+
+// appsPOSTHandler should be served at https://example.org/api/v1/apps
+// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
+func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
+ l := m.log.WithField("func", "AppsPOSTHandler")
+ l.Trace("entering AppsPOSTHandler")
+
+ form := &mastotypes.ApplicationPOSTRequest{}
+ if err := c.ShouldBind(form); err != nil {
+ c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
+ return
+ }
+
+ // permitted length for most fields
+ permittedLength := 64
+ // redirect can be a bit bigger because we probably need to encode data in the redirect uri
+ permittedRedirect := 256
+
+ // check lengths of fields before proceeding so the user can't spam huge entries into the database
+ if len(form.ClientName) > permittedLength {
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
+ return
+ }
+ if len(form.Website) > permittedLength {
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
+ return
+ }
+ if len(form.RedirectURIs) > permittedRedirect {
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
+ return
+ }
+ if len(form.Scopes) > permittedLength {
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
+ return
+ }
+
+ // set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
+ var scopes string
+ if form.Scopes == "" {
+ scopes = "read"
+ } else {
+ scopes = form.Scopes
+ }
+
+ // generate new IDs for this application and its associated client
+ clientID := uuid.NewString()
+ clientSecret := uuid.NewString()
+ vapidKey := uuid.NewString()
+
+ // generate the application to put in the database
+ app := >smodel.Application{
+ Name: form.ClientName,
+ Website: form.Website,
+ RedirectURI: form.RedirectURIs,
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ Scopes: scopes,
+ VapidKey: vapidKey,
+ }
+
+ // chuck it in the db
+ if err := m.db.Put(app); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ // now we need to model an oauth client from the application that the oauth library can use
+ oc := &oauthClient{
+ ID: clientID,
+ Secret: clientSecret,
+ Domain: form.RedirectURIs,
+ UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
+ }
+
+ // chuck it in the db
+ if err := m.db.Put(oc); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ // done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
+ c.JSON(http.StatusOK, app.ToMastotype())
+}
+
+// signInGETHandler should be served at https://example.org/auth/sign_in.
+// The idea is to present a sign in page to the user, where they can enter their username and password.
+// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
+func (m *oauthModule) signInGETHandler(c *gin.Context) {
+ m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
+ c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
+}
+
+// signInPOSTHandler should be served at https://example.org/auth/sign_in.
+// The idea is to present a sign in page to the user, where they can enter their username and password.
+// The handler will then redirect to the auth handler served at /auth
+func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
+ l := m.log.WithField("func", "SignInPOSTHandler")
+ s := sessions.Default(c)
+ form := &login{}
+ if err := c.ShouldBind(form); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+ l.Tracef("parsed form: %+v", form)
+
+ userid, err := m.validatePassword(form.Email, form.Password)
+ if err != nil {
+ c.String(http.StatusForbidden, err.Error())
+ return
+ }
+
+ s.Set("userid", userid)
+ if err := s.Save(); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ l.Trace("redirecting to auth page")
+ c.Redirect(http.StatusFound, oauthAuthorizePath)
+}
+
+// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
+// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
+// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
+func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
+ l := m.log.WithField("func", "TokenPOSTHandler")
+ l.Trace("entered TokenPOSTHandler")
+ if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ }
+}
+
+// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
+// The idea here is to present an oauth authorize page to the user, with a button
+// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
+func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
+ l := m.log.WithField("func", "AuthorizeGETHandler")
+ s := sessions.Default(c)
+
+ // UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
+ // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
+ userID, ok := s.Get("userid").(string)
+ if !ok || userID == "" {
+ l.Trace("userid was empty, parsing form then redirecting to sign in page")
+ if err := parseAuthForm(c, l); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ } else {
+ c.Redirect(http.StatusFound, authSignInPath)
+ }
+ return
+ }
+
+ // We can use the client_id on the session to retrieve info about the app associated with the client_id
+ clientID, ok := s.Get("client_id").(string)
+ if !ok || clientID == "" {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
+ return
+ }
+ app := >smodel.Application{
+ ClientID: clientID,
+ }
+ if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
+ return
+ }
+
+ // we can also use the userid of the user to fetch their username from the db to greet them nicely <3
+ user := >smodel.User{
+ ID: userID,
+ }
+ if err := m.db.GetByID(user.ID, user); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ acct := >smodel.Account{
+ ID: user.AccountID,
+ }
+
+ if err := m.db.GetByID(acct.ID, acct); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ // Finally we should also get the redirect and scope of this particular request, as stored in the session.
+ redirect, ok := s.Get("redirect_uri").(string)
+ if !ok || redirect == "" {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"})
+ return
+ }
+ scope, ok := s.Get("scope").(string)
+ if !ok || scope == "" {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"})
+ return
+ }
+
+ // the authorize template will display a form to the user where they can get some information
+ // about the app that's trying to authorize, and the scope of the request.
+ // They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler
+ l.Trace("serving authorize html")
+ c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
+ "appname": app.Name,
+ "appwebsite": app.Website,
+ "redirect": redirect,
+ "scope": scope,
+ "user": acct.Username,
+ })
+}
+
+// authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
+// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
+// so we should proceed with the authentication flow and generate an oauth token for them if we can.
+// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
+func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
+ l := m.log.WithField("func", "AuthorizePOSTHandler")
+ s := sessions.Default(c)
+
+ // At this point we know the user has said 'yes' to allowing the application and oauth client
+ // work for them, so we can set the
+
+ // We need to retrieve the original form submitted to the authorizeGEThandler, and
+ // recreate it on the request so that it can be used further by the oauth2 library.
+ // So first fetch all the values from the session.
+ forceLogin, ok := s.Get("force_login").(string)
+ if !ok {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
+ return
+ }
+ responseType, ok := s.Get("response_type").(string)
+ if !ok || responseType == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
+ return
+ }
+ clientID, ok := s.Get("client_id").(string)
+ if !ok || clientID == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
+ return
+ }
+ redirectURI, ok := s.Get("redirect_uri").(string)
+ if !ok || redirectURI == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
+ return
+ }
+ scope, ok := s.Get("scope").(string)
+ if !ok {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
+ return
+ }
+ userID, ok := s.Get("userid").(string)
+ if !ok {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"})
+ return
+ }
+ // we're done with the session so we can clear it now
+ s.Clear()
+
+ // now set the values on the request
+ values := url.Values{}
+ values.Set("force_login", forceLogin)
+ values.Set("response_type", responseType)
+ values.Set("client_id", clientID)
+ values.Set("redirect_uri", redirectURI)
+ values.Set("scope", scope)
+ values.Set("userid", userID)
+ c.Request.Form = values
+ l.Tracef("values on request set to %+v", c.Request.Form)
+
+ // and proceed with authorization using the oauth2 library
+ if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ }
+}
+
+/*
+ MIDDLEWARE
+*/
+
+// oauthTokenMiddleware
+func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
+ l := m.log.WithField("func", "ValidatePassword")
+ l.Trace("entering OauthTokenMiddleware")
+ if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil {
+ l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
+ c.Set("authenticated_user", ti.GetUserID())
+
+ } else {
+ l.Trace("continuing with unauthenticated request")
+ }
+}
+
+/*
+ SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server or used inside handler funcs
+*/
+
+// validatePassword takes an email address and a password.
+// The goal is to authenticate the password against the one for that email
+// address stored in the database. If OK, we return the userid (a uuid) for that user,
+// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
+func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) {
+ l := m.log.WithField("func", "ValidatePassword")
+
+ // make sure an email/password was provided and bail if not
+ if email == "" || password == "" {
+ l.Debug("email or password was not provided")
+ return incorrectPassword()
+ }
+
+ // first we select the user from the database based on email address, bail if no user found for that email
+ gtsUser := >smodel.User{}
+
+ if err := m.db.GetWhere("email", email, gtsUser); err != nil {
+ l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
+ return incorrectPassword()
+ }
+
+ // make sure a password is actually set and bail if not
+ if gtsUser.EncryptedPassword == "" {
+ l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
+ return incorrectPassword()
+ }
+
+ // compare the provided password with the encrypted one from the db, bail if they don't match
+ if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
+ l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
+ return incorrectPassword()
+ }
+
+ // If we've made it this far the email/password is correct, so we can just return the id of the user.
+ userid = gtsUser.ID
+ l.Tracef("returning (%s, %s)", userid, err)
+ return
+}
+
+// incorrectPassword is just a little helper function to use in the ValidatePassword function
+func incorrectPassword() (string, error) {
+ return "", errors.New("password/email combination was incorrect")
+}
+
+// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
+// or redirects to the /auth/sign_in page, if this key is not present.
+func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
+ l := m.log.WithField("func", "UserAuthorizationHandler")
+ userID = r.FormValue("userid")
+ if userID == "" {
+ return "", errors.New("userid was empty, redirecting to sign in page")
+ }
+ l.Tracef("returning userID %s", userID)
+ return userID, err
+}
+
+// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
+// the values in the form into the session.
+func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
+ s := sessions.Default(c)
+
+ // first make sure they've filled out the authorize form with the required values
+ form := &mastotypes.OAuthAuthorize{}
+ if err := c.ShouldBind(form); err != nil {
+ return err
+ }
+ l.Tracef("parsed form: %+v", form)
+
+ // these fields are *required* so check 'em
+ if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
+ return errors.New("missing one of: response_type, client_id or redirect_uri")
+ }
+
+ // set default scope to read
+ if form.Scope == "" {
+ form.Scope = "read"
+ }
+
+ // save these values from the form so we can use them elsewhere in the session
+ s.Set("force_login", form.ForceLogin)
+ s.Set("response_type", form.ResponseType)
+ s.Set("client_id", form.ClientID)
+ s.Set("redirect_uri", form.RedirectURI)
+ s.Set("scope", form.Scope)
+ return s.Save()
+}
diff --git a/internal/module/oauth/oauth_test.go b/internal/module/oauth/oauth_test.go
@@ -0,0 +1,191 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package oauth
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/gotosocial/gotosocial/internal/config"
+ "github.com/gotosocial/gotosocial/internal/db"
+ "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/gotosocial/gotosocial/internal/router"
+ "github.com/gotosocial/oauth2/v4"
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/suite"
+ "golang.org/x/crypto/bcrypt"
+)
+
+type OauthTestSuite struct {
+ suite.Suite
+ tokenStore oauth2.TokenStore
+ clientStore oauth2.ClientStore
+ db db.DB
+ testAccount *gtsmodel.Account
+ testApplication *gtsmodel.Application
+ testUser *gtsmodel.User
+ testClient *oauthClient
+ config *config.Config
+}
+
+// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
+func (suite *OauthTestSuite) SetupSuite() {
+ c := config.Empty()
+ // we're running on localhost without https so set the protocol to http
+ c.Protocol = "http"
+ // just for testing
+ c.Host = "localhost:8080"
+ // because go tests are run within the test package directory, we need to fiddle with the templateconfig
+ // basedir in a way that we wouldn't normally have to do when running the binary, in order to make
+ // the templates actually load
+ c.TemplateConfig.BaseDir = "../../../web/template/"
+ c.DBConfig = &config.DBConfig{
+ Type: "postgres",
+ Address: "localhost",
+ Port: 5432,
+ User: "postgres",
+ Password: "postgres",
+ Database: "postgres",
+ ApplicationName: "gotosocial",
+ }
+ suite.config = c
+
+ encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
+ if err != nil {
+ logrus.Panicf("error encrypting user pass: %s", err)
+ }
+
+ acctID := uuid.NewString()
+
+ suite.testAccount = >smodel.Account{
+ ID: acctID,
+ Username: "test_user",
+ }
+ suite.testUser = >smodel.User{
+ EncryptedPassword: string(encryptedPassword),
+ Email: "user@example.org",
+ AccountID: acctID,
+ }
+ suite.testClient = &oauthClient{
+ ID: "a-known-client-id",
+ Secret: "some-secret",
+ Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
+ }
+ suite.testApplication = >smodel.Application{
+ Name: "a test application",
+ Website: "https://some-application-website.com",
+ RedirectURI: "http://localhost:8080",
+ ClientID: "a-known-client-id",
+ ClientSecret: "some-secret",
+ Scopes: "read",
+ VapidKey: uuid.NewString(),
+ }
+}
+
+// SetupTest creates a postgres connection and creates the oauth_clients table before each test
+func (suite *OauthTestSuite) SetupTest() {
+
+ log := logrus.New()
+ log.SetLevel(logrus.TraceLevel)
+ db, err := db.New(context.Background(), suite.config, log)
+ if err != nil {
+ logrus.Panicf("error creating database connection: %s", err)
+ }
+
+ suite.db = db
+
+ models := []interface{}{
+ &oauthClient{},
+ &oauthToken{},
+ >smodel.User{},
+ >smodel.Account{},
+ >smodel.Application{},
+ }
+
+ for _, m := range models {
+ if err := suite.db.CreateTable(m); err != nil {
+ logrus.Panicf("db connection error: %s", err)
+ }
+ }
+
+ suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New())
+ suite.clientStore = newClientStore(suite.db)
+
+ if err := suite.db.Put(suite.testAccount); err != nil {
+ logrus.Panicf("could not insert test account into db: %s", err)
+ }
+ if err := suite.db.Put(suite.testUser); err != nil {
+ logrus.Panicf("could not insert test user into db: %s", err)
+ }
+ if err := suite.db.Put(suite.testClient); err != nil {
+ logrus.Panicf("could not insert test client into db: %s", err)
+ }
+ if err := suite.db.Put(suite.testApplication); err != nil {
+ logrus.Panicf("could not insert test application into db: %s", err)
+ }
+
+}
+
+// TearDownTest drops the oauth_clients table and closes the pg connection after each test
+func (suite *OauthTestSuite) TearDownTest() {
+ models := []interface{}{
+ &oauthClient{},
+ &oauthToken{},
+ >smodel.User{},
+ >smodel.Account{},
+ >smodel.Application{},
+ }
+ for _, m := range models {
+ if err := suite.db.DropTable(m); err != nil {
+ logrus.Panicf("error dropping table: %s", err)
+ }
+ }
+ if err := suite.db.Stop(context.Background()); err != nil {
+ logrus.Panicf("error closing db connection: %s", err)
+ }
+ suite.db = nil
+}
+
+func (suite *OauthTestSuite) TestAPIInitialize() {
+ log := logrus.New()
+ log.SetLevel(logrus.TraceLevel)
+
+ r, err := router.New(suite.config, log)
+ if err != nil {
+ suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
+ }
+
+ api := New(suite.tokenStore, suite.clientStore, suite.db, log)
+ if err := api.Route(r); err != nil {
+ suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
+ }
+
+ go r.Start()
+ time.Sleep(60 * time.Second)
+ // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read
+ // curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token
+ // curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080
+}
+
+func TestOauthTestSuite(t *testing.T) {
+ suite.Run(t, new(OauthTestSuite))
+}
diff --git a/internal/module/oauth/tokenstore.go b/internal/module/oauth/tokenstore.go
@@ -0,0 +1,251 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package oauth
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/gotosocial/gotosocial/internal/db"
+ "github.com/gotosocial/oauth2/v4"
+ "github.com/gotosocial/oauth2/v4/models"
+ "github.com/sirupsen/logrus"
+)
+
+// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
+type tokenStore struct {
+ oauth2.TokenStore
+ db db.DB
+ log *logrus.Logger
+}
+
+// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
+//
+// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
+// the tokens in the DB once per minute and deletes any that have expired.
+func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore {
+ pts := &tokenStore{
+ db: db,
+ log: log,
+ }
+
+ // set the token store to clean out expired tokens once per minute, or return if we're done
+ go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) {
+ cleanloop:
+ for {
+ select {
+ case <-ctx.Done():
+ log.Info("breaking cleanloop")
+ break cleanloop
+ case <-time.After(1 * time.Minute):
+ log.Debug("sweeping out old oauth entries broom broom")
+ if err := pts.sweep(); err != nil {
+ log.Errorf("error while sweeping oauth entries: %s", err)
+ }
+ }
+ }
+ }(ctx, pts, log)
+ return pts
+}
+
+// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
+func (pts *tokenStore) sweep() error {
+ // select *all* tokens from the db
+ // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
+ tokens := new([]*oauthToken)
+ if err := pts.db.GetAll(tokens); err != nil {
+ return err
+ }
+
+ // iterate through and remove expired tokens
+ now := time.Now()
+ for _, pgt := range *tokens {
+ // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
+ // we only want to check if a token expired before now if the expiry time is *not zero*;
+ // ie., if it's been explicity set.
+ if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
+ if err := pts.db.DeleteByID(pgt.ID, &pgt); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// Create creates and store the new token information.
+// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34
+func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
+ t, ok := info.(*models.Token)
+ if !ok {
+ return errors.New("info param was not a models.Token")
+ }
+ if err := pts.db.Put(oauthTokenToPGToken(t)); err != nil {
+ return fmt.Errorf("error in tokenstore create: %s", err)
+ }
+ return nil
+}
+
+// RemoveByCode deletes a token from the DB based on the Code field
+func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
+ return pts.db.DeleteWhere("code", code, &oauthToken{})
+}
+
+// RemoveByAccess deletes a token from the DB based on the Access field
+func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
+ return pts.db.DeleteWhere("access", access, &oauthToken{})
+}
+
+// RemoveByRefresh deletes a token from the DB based on the Refresh field
+func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
+ return pts.db.DeleteWhere("refresh", refresh, &oauthToken{})
+}
+
+// GetByCode selects a token from the DB based on the Code field
+func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
+ pgt := &oauthToken{
+ Code: code,
+ }
+ if err := pts.db.GetWhere("code", code, pgt); err != nil {
+ return nil, err
+ }
+ return pgTokenToOauthToken(pgt), nil
+}
+
+// GetByAccess selects a token from the DB based on the Access field
+func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
+ pgt := &oauthToken{
+ Access: access,
+ }
+ if err := pts.db.GetWhere("access", access, pgt); err != nil {
+ return nil, err
+ }
+ return pgTokenToOauthToken(pgt), nil
+}
+
+// GetByRefresh selects a token from the DB based on the Refresh field
+func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
+ pgt := &oauthToken{
+ Refresh: refresh,
+ }
+ if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
+ return nil, err
+ }
+ return pgTokenToOauthToken(pgt), nil
+}
+
+/*
+ The following models are basically helpers for the postgres token store implementation, they should only be used internally.
+*/
+
+// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
+//
+// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
+// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
+// then periodically sweep out tokens when that time has passed.
+//
+// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
+// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
+// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
+// and pgTokenToOauthToken can be used for that.
+type oauthToken struct {
+ ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
+ ClientID string
+ UserID string
+ RedirectURI string
+ Scope string
+ Code string `pg:"default:'',pk"`
+ CodeChallenge string
+ CodeChallengeMethod string
+ CodeCreateAt time.Time `pg:"type:timestamp"`
+ CodeExpiresAt time.Time `pg:"type:timestamp"`
+ Access string `pg:"default:'',pk"`
+ AccessCreateAt time.Time `pg:"type:timestamp"`
+ AccessExpiresAt time.Time `pg:"type:timestamp"`
+ Refresh string `pg:"default:'',pk"`
+ RefreshCreateAt time.Time `pg:"type:timestamp"`
+ RefreshExpiresAt time.Time `pg:"type:timestamp"`
+}
+
+// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
+func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
+ now := time.Now()
+
+ // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
+ // going to cause all sorts of interesting problems. So check first to make sure that the ExpiresIn is not equal
+ // to the zero value of a time.Duration, which is 0s. If it *is* empty/nil, just leave the ExpiresAt at nil as well.
+
+ var cea time.Time
+ if tkn.CodeExpiresIn != 0*time.Second {
+ cea = now.Add(tkn.CodeExpiresIn)
+ }
+
+ var aea time.Time
+ if tkn.AccessExpiresIn != 0*time.Second {
+ aea = now.Add(tkn.AccessExpiresIn)
+ }
+
+ var rea time.Time
+ if tkn.RefreshExpiresIn != 0*time.Second {
+ rea = now.Add(tkn.RefreshExpiresIn)
+ }
+
+ return &oauthToken{
+ ClientID: tkn.ClientID,
+ UserID: tkn.UserID,
+ RedirectURI: tkn.RedirectURI,
+ Scope: tkn.Scope,
+ Code: tkn.Code,
+ CodeChallenge: tkn.CodeChallenge,
+ CodeChallengeMethod: tkn.CodeChallengeMethod,
+ CodeCreateAt: tkn.CodeCreateAt,
+ CodeExpiresAt: cea,
+ Access: tkn.Access,
+ AccessCreateAt: tkn.AccessCreateAt,
+ AccessExpiresAt: aea,
+ Refresh: tkn.Refresh,
+ RefreshCreateAt: tkn.RefreshCreateAt,
+ RefreshExpiresAt: rea,
+ }
+}
+
+// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
+func pgTokenToOauthToken(pgt *oauthToken) *models.Token {
+ now := time.Now()
+
+ return &models.Token{
+ ClientID: pgt.ClientID,
+ UserID: pgt.UserID,
+ RedirectURI: pgt.RedirectURI,
+ Scope: pgt.Scope,
+ Code: pgt.Code,
+ CodeChallenge: pgt.CodeChallenge,
+ CodeChallengeMethod: pgt.CodeChallengeMethod,
+ CodeCreateAt: pgt.CodeCreateAt,
+ CodeExpiresIn: pgt.CodeExpiresAt.Sub(now),
+ Access: pgt.Access,
+ AccessCreateAt: pgt.AccessCreateAt,
+ AccessExpiresIn: pgt.AccessExpiresAt.Sub(now),
+ Refresh: pgt.Refresh,
+ RefreshCreateAt: pgt.RefreshCreateAt,
+ RefreshExpiresIn: pgt.RefreshExpiresAt.Sub(now),
+ }
+}
diff --git a/internal/oauth/README.md b/internal/oauth/README.md
@@ -1,3 +0,0 @@
-# oauth
-
-This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) server functionality to the GoToSocial APIs.
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
@@ -1,446 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package oauth
-
-import (
- "fmt"
- "net/http"
- "net/url"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- "github.com/go-pg/pg/v10"
- "github.com/google/uuid"
- "github.com/gotosocial/gotosocial/internal/api"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
- "github.com/gotosocial/gotosocial/pkg/mastotypes"
- "github.com/gotosocial/oauth2/v4"
- "github.com/gotosocial/oauth2/v4/errors"
- "github.com/gotosocial/oauth2/v4/manage"
- "github.com/gotosocial/oauth2/v4/server"
- "github.com/sirupsen/logrus"
- "golang.org/x/crypto/bcrypt"
-)
-
-type API struct {
- manager *manage.Manager
- server *server.Server
- conn *pg.DB
- log *logrus.Logger
-}
-
-type login struct {
- Email string `form:"username"`
- Password string `form:"password"`
-}
-
-type code struct {
- Code string `form:"code"`
-}
-
-func New(ts oauth2.TokenStore, cs oauth2.ClientStore, conn *pg.DB, log *logrus.Logger) *API {
- manager := manage.NewDefaultManager()
- manager.MapTokenStorage(ts)
- manager.MapClientStorage(cs)
- manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
- sc := &server.Config{
- TokenType: "Bearer",
- // Must follow the spec.
- AllowGetAccessRequest: false,
- // Support only the non-implicit flow.
- AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
- // Allow:
- // - Authorization Code (for first & third parties)
- // - Refreshing Tokens
- //
- // Deny:
- // - Resource owner secrets (password grant)
- // - Client secrets
- AllowedGrantTypes: []oauth2.GrantType{
- oauth2.AuthorizationCode,
- oauth2.Refreshing,
- },
- AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
- oauth2.CodeChallengePlain,
- },
- }
-
- srv := server.NewServer(sc, manager)
- srv.SetInternalErrorHandler(func(err error) *errors.Response {
- log.Errorf("internal oauth error: %s", err)
- return nil
- })
-
- srv.SetResponseErrorHandler(func(re *errors.Response) {
- log.Errorf("internal response error: %s", re.Error)
- })
-
- api := &API{
- manager: manager,
- server: srv,
- conn: conn,
- log: log,
- }
-
- api.server.SetUserAuthorizationHandler(api.UserAuthorizationHandler)
- api.server.SetClientInfoHandler(server.ClientFormHandler)
- return api
-}
-
-func (a *API) AddRoutes(s api.Server) error {
- s.AttachHandler(http.MethodPost, "/api/v1/apps", a.AppsPOSTHandler)
-
- s.AttachHandler(http.MethodGet, "/auth/sign_in", a.SignInGETHandler)
- s.AttachHandler(http.MethodPost, "/auth/sign_in", a.SignInPOSTHandler)
-
- s.AttachHandler(http.MethodPost, "/oauth/token", a.TokenPOSTHandler)
-
- s.AttachHandler(http.MethodGet, "/oauth/authorize", a.AuthorizeGETHandler)
- s.AttachHandler(http.MethodPost, "/oauth/authorize", a.AuthorizePOSTHandler)
-
- return nil
-}
-
-func incorrectPassword() (string, error) {
- return "", errors.New("password/email combination was incorrect")
-}
-
-/*
- MAIN HANDLERS -- serve these through a server/router
-*/
-
-// AppsPOSTHandler should be served at https://example.org/api/v1/apps
-// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
-func (a *API) AppsPOSTHandler(c *gin.Context) {
- l := a.log.WithField("func", "AppsPOSTHandler")
- l.Trace("entering AppsPOSTHandler")
-
- form := &mastotypes.ApplicationPOSTRequest{}
- if err := c.ShouldBind(form); err != nil {
- c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
- return
- }
-
- // permitted length for most fields
- permittedLength := 64
- // redirect can be a bit bigger because we probably need to encode data in the redirect uri
- permittedRedirect := 256
-
- // check lengths of fields before proceeding so the user can't spam huge entries into the database
- if len(form.ClientName) > permittedLength {
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
- return
- }
- if len(form.Website) > permittedLength {
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
- return
- }
- if len(form.RedirectURIs) > permittedRedirect {
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
- return
- }
- if len(form.Scopes) > permittedLength {
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
- return
- }
-
- // set default 'read' for scopes if it's not set
- var scopes string
- if form.Scopes == "" {
- scopes = "read"
- } else {
- scopes = form.Scopes
- }
-
- // generate new IDs for this application and its associated client
- clientID := uuid.NewString()
- clientSecret := uuid.NewString()
- vapidKey := uuid.NewString()
-
- // generate the application to put in the database
- app := >smodel.Application{
- Name: form.ClientName,
- Website: form.Website,
- RedirectURI: form.RedirectURIs,
- ClientID: clientID,
- ClientSecret: clientSecret,
- Scopes: scopes,
- VapidKey: vapidKey,
- }
-
- // chuck it in the db
- if _, err := a.conn.Model(app).Insert(); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- // now we need to model an oauth client from the application that the oauth library can use
- oc := &oauthClient{
- ID: clientID,
- Secret: clientSecret,
- Domain: form.RedirectURIs,
- UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
- }
-
- // chuck it in the db
- if _, err := a.conn.Model(oc).Insert(); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- // done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
- c.JSON(http.StatusOK, app)
-}
-
-// SignInGETHandler should be served at https://example.org/auth/sign_in.
-// The idea is to present a sign in page to the user, where they can enter their username and password.
-// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
-func (a *API) SignInGETHandler(c *gin.Context) {
- a.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
- c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
-}
-
-// SignInPOSTHandler should be served at https://example.org/auth/sign_in.
-// The idea is to present a sign in page to the user, where they can enter their username and password.
-// The handler will then redirect to the auth handler served at /auth
-func (a *API) SignInPOSTHandler(c *gin.Context) {
- l := a.log.WithField("func", "SignInPOSTHandler")
- s := sessions.Default(c)
- form := &login{}
- if err := c.ShouldBind(form); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- l.Tracef("parsed form: %+v", form)
-
- userid, err := a.ValidatePassword(form.Email, form.Password)
- if err != nil {
- c.String(http.StatusForbidden, err.Error())
- return
- }
-
- s.Set("username", userid)
- if err := s.Save(); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- l.Trace("redirecting to auth page")
- c.Redirect(http.StatusFound, "/oauth/authorize")
-}
-
-// TokenPOSTHandler should be served as a POST at https://example.org/oauth/token
-// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
-// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
-func (a *API) TokenPOSTHandler(c *gin.Context) {
- l := a.log.WithField("func", "TokenHandler")
- l.Trace("entered token handler, will now go to server.HandleTokenRequest")
- if err := a.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- }
-}
-
-// AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize
-// The idea here is to present an oauth authorize page to the user, with a button
-// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
-func (a *API) AuthorizeGETHandler(c *gin.Context) {
- l := a.log.WithField("func", "AuthorizeGETHandler")
- s := sessions.Default(c)
-
- // Username will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
- // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
- v := s.Get("username")
- if username, ok := v.(string); !ok || username == "" {
- l.Trace("username was empty, parsing form then redirecting to sign in page")
-
- // first make sure they've filled out the authorize form with the required values
- form := &mastotypes.OAuthAuthorize{}
- if err := c.ShouldBind(form); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- l.Tracef("parsed form: %+v", form)
-
- // these fields are *required* so check 'em
- if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
- c.JSON(http.StatusBadRequest, gin.H{"error": "missing one of: response_type, client_id or redirect_uri"})
- return
- }
-
- // save these values from the form so we can use them elsewhere in the session
- s.Set("force_login", form.ForceLogin)
- s.Set("response_type", form.ResponseType)
- s.Set("client_id", form.ClientID)
- s.Set("redirect_uri", form.RedirectURI)
- s.Set("scope", form.Scope)
- if err := s.Save(); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- // send them to the sign in page so we can tell who they are
- c.Redirect(http.StatusFound, "/auth/sign_in")
- return
- }
-
- // Check if we have a code already. If we do, it means the user used urn:ietf:wg:oauth:2.0:oob as their redirect URI
- // and were sent here, which means they just want the code displayed so they can use it out of band.
- code := &code{}
- if err := c.Bind(code); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- // the authorize template will either:
- // 1. Display the code to the user if they're already authorized and were redirected here because they selected urn:ietf:wg:oauth:2.0:oob.
- // 2. Display a form where they can get some information about the app that's trying to authorize, and approve it, which will then go to AuthorizePOSTHandler
- l.Trace("serving authorize html")
- c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
- "code": code.Code,
- })
-}
-
-// AuthorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
-// The idea here is to present an oauth authorize page to the user, with a button
-// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
-func (a *API) AuthorizePOSTHandler(c *gin.Context) {
- l := a.log.WithField("func", "AuthorizePOSTHandler")
- s := sessions.Default(c)
-
- v := s.Get("username")
- if username, ok := v.(string); !ok || username == "" {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not signed in"})
- }
-
- values := url.Values{}
-
- if v, ok := s.Get("force_login").(string); !ok {
- c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
- return
- } else {
- values.Add("force_login", v)
- }
-
- if v, ok := s.Get("response_type").(string); !ok {
- c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
- return
- } else {
- values.Add("response_type", v)
- }
-
- if v, ok := s.Get("client_id").(string); !ok {
- c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
- return
- } else {
- values.Add("client_id", v)
- }
-
- if v, ok := s.Get("redirect_uri").(string); !ok {
- c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
- return
- } else {
- // todo: explain this little hack
- if v == "urn:ietf:wg:oauth:2.0:oob" {
- v = "http://localhost:8080/oauth/authorize"
- }
- values.Add("redirect_uri", v)
- }
-
- if v, ok := s.Get("scope").(string); !ok {
- c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
- return
- } else {
- values.Add("scope", v)
- }
-
- if v, ok := s.Get("username").(string); !ok {
- c.JSON(http.StatusBadRequest, gin.H{"error": "session missing username"})
- return
- } else {
- values.Add("username", v)
- }
-
- c.Request.Form = values
- l.Tracef("values on request set to %+v", c.Request.Form)
-
- if err := s.Save(); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- if err := a.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- }
-}
-
-/*
- SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server
-*/
-
-// PasswordAuthorizationHandler takes a username (in this case, we use an email address)
-// and a password. The goal is to authenticate the password against the one for that email
-// address stored in the database. If OK, we return the userid (a uuid) for that user,
-// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
-func (a *API) ValidatePassword(email string, password string) (userid string, err error) {
- l := a.log.WithField("func", "PasswordAuthorizationHandler")
-
- // make sure an email/password was provided and bail if not
- if email == "" || password == "" {
- l.Debug("email or password was not provided")
- return incorrectPassword()
- }
-
- // first we select the user from the database based on email address, bail if no user found for that email
- gtsUser := >smodel.User{}
- if err := a.conn.Model(gtsUser).Where("email = ?", email).Select(); err != nil {
- l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
- return incorrectPassword()
- }
-
- // make sure a password is actually set and bail if not
- if gtsUser.EncryptedPassword == "" {
- l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
- return incorrectPassword()
- }
-
- // compare the provided password with the encrypted one from the db, bail if they don't match
- if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
- l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
- return incorrectPassword()
- }
-
- // If we've made it this far the email/password is correct, so we can just return the id of the user.
- userid = gtsUser.ID
- l.Tracef("returning (%s, %s)", userid, err)
- return
-}
-
-// UserAuthorizationHandler gets the user's ID from the 'username' field of the request form,
-// or redirects to the /auth/sign_in page, if this key is not present.
-func (a *API) UserAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
- l := a.log.WithField("func", "UserAuthorizationHandler")
- userID = r.FormValue("username")
- if userID == "" {
- l.Trace("username was empty, redirecting to sign in page")
- http.Redirect(w, r, "/auth/sign_in", http.StatusFound)
- return "", nil
- }
- l.Tracef("returning (%s, %s)", userID, err)
- return userID, err
-}
diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go
@@ -1,133 +0,0 @@
-package oauth
-
-import (
- "context"
- "fmt"
- "testing"
- "time"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/gotosocial/gotosocial/internal/api"
- "github.com/gotosocial/gotosocial/internal/config"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
- "github.com/gotosocial/oauth2/v4"
- "github.com/sirupsen/logrus"
- "github.com/stretchr/testify/suite"
- "golang.org/x/crypto/bcrypt"
-)
-
-type OauthTestSuite struct {
- suite.Suite
- tokenStore oauth2.TokenStore
- clientStore oauth2.ClientStore
- conn *pg.DB
- testAccount *gtsmodel.Account
- testUser *gtsmodel.User
- testClient *oauthClient
- config *config.Config
-}
-
-const ()
-
-// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
-func (suite *OauthTestSuite) SetupSuite() {
- encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("test-password"), bcrypt.DefaultCost)
- if err != nil {
- logrus.Panicf("error encrypting user pass: %s", err)
- }
-
- suite.testAccount = >smodel.Account{}
- suite.testUser = >smodel.User{
- EncryptedPassword: string(encryptedPassword),
- Email: "user@localhost",
- AccountID: "some-account-id-it-doesn't-matter-really-since-this-user-doesn't-actually-have-an-account!",
- }
- suite.testClient = &oauthClient{
- ID: "a-known-client-id",
- Secret: "some-secret",
- Domain: "http://localhost:8080",
- }
-
- // because go tests are run within the test package directory, we need to fiddle with the templateconfig
- // basedir in a way that we wouldn't normally have to do when running the binary, in order to make
- // the templates actually load
- c := config.Empty()
- c.TemplateConfig.BaseDir = "../../web/template/"
- suite.config = c
-}
-
-// SetupTest creates a postgres connection and creates the oauth_clients table before each test
-func (suite *OauthTestSuite) SetupTest() {
- suite.conn = pg.Connect(&pg.Options{})
- if err := suite.conn.Ping(context.Background()); err != nil {
- logrus.Panicf("db connection error: %s", err)
- }
-
- models := []interface{}{
- &oauthClient{},
- &oauthToken{},
- >smodel.User{},
- >smodel.Account{},
- >smodel.Application{},
- }
-
- for _, m := range models {
- if err := suite.conn.Model(m).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- }); err != nil {
- logrus.Panicf("db connection error: %s", err)
- }
- }
-
- suite.tokenStore = NewPGTokenStore(context.Background(), suite.conn, logrus.New())
- suite.clientStore = NewPGClientStore(suite.conn)
-
- if _, err := suite.conn.Model(suite.testUser).Insert(); err != nil {
- logrus.Panicf("could not insert test user into db: %s", err)
- }
-
- if _, err := suite.conn.Model(suite.testClient).Insert(); err != nil {
- logrus.Panicf("could not insert test client into db: %s", err)
- }
-
-}
-
-// TearDownTest drops the oauth_clients table and closes the pg connection after each test
-func (suite *OauthTestSuite) TearDownTest() {
- models := []interface{}{
- &oauthClient{},
- &oauthToken{},
- >smodel.User{},
- >smodel.Account{},
- >smodel.Application{},
- }
- for _, m := range models {
- if err := suite.conn.Model(m).DropTable(&orm.DropTableOptions{}); err != nil {
- logrus.Panicf("drop table error: %s", err)
- }
- }
- if err := suite.conn.Close(); err != nil {
- logrus.Panicf("error closing db connection: %s", err)
- }
- suite.conn = nil
-}
-
-func (suite *OauthTestSuite) TestAPIInitialize() {
- log := logrus.New()
- log.SetLevel(logrus.TraceLevel)
-
- r := api.New(suite.config, log)
- api := New(suite.tokenStore, suite.clientStore, suite.conn, log)
- if err := api.AddRoutes(r); err != nil {
- suite.FailNow(fmt.Sprintf("error initializing api: %s", err))
- }
- go r.Start()
- time.Sleep(30 * time.Second)
- // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=https://example.org
- // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=urn:ietf:wg:oauth:2.0:oob
-}
-
-func TestOauthTestSuite(t *testing.T) {
- suite.Run(t, new(OauthTestSuite))
-}
diff --git a/internal/oauth/pgclientstore.go b/internal/oauth/pgclientstore.go
@@ -1,81 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package oauth
-
-import (
- "context"
- "fmt"
-
- "github.com/go-pg/pg/v10"
- "github.com/gotosocial/oauth2/v4"
- "github.com/gotosocial/oauth2/v4/models"
-)
-
-type pgClientStore struct {
- conn *pg.DB
-}
-
-func NewPGClientStore(conn *pg.DB) oauth2.ClientStore {
- pts := &pgClientStore{
- conn: conn,
- }
- return pts
-}
-
-func (pcs *pgClientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
- poc := &oauthClient{
- ID: clientID,
- }
- if err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Select(); err != nil {
- return nil, fmt.Errorf("error in clientstore getbyid searching for client %s: %s", clientID, err)
- }
- return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
-}
-
-func (pcs *pgClientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
- poc := &oauthClient{
- ID: cli.GetID(),
- Secret: cli.GetSecret(),
- Domain: cli.GetDomain(),
- UserID: cli.GetUserID(),
- }
- _, err := pcs.conn.WithContext(ctx).Model(poc).OnConflict("(id) DO UPDATE").Insert()
- if err != nil {
- return fmt.Errorf("error in clientstore set: %s", err)
- }
- return nil
-}
-
-func (pcs *pgClientStore) Delete(ctx context.Context, id string) error {
- poc := &oauthClient{
- ID: id,
- }
- _, err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Delete()
- if err != nil {
- return fmt.Errorf("error in clientstore delete: %s", err)
- }
- return nil
-}
-
-type oauthClient struct {
- ID string
- Secret string
- Domain string
- UserID string
-}
diff --git a/internal/oauth/pgclientstore_test.go b/internal/oauth/pgclientstore_test.go
@@ -1,103 +0,0 @@
-package oauth
-
-import (
- "context"
- "testing"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/gotosocial/oauth2/v4/models"
- "github.com/sirupsen/logrus"
- "github.com/stretchr/testify/suite"
-)
-
-type PgClientStoreTestSuite struct {
- suite.Suite
- conn *pg.DB
- testClientID string
- testClientSecret string
- testClientDomain string
- testClientUserID string
-}
-
-const ()
-
-// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
-func (suite *PgClientStoreTestSuite) SetupSuite() {
- suite.testClientID = "test-client-id"
- suite.testClientSecret = "test-client-secret"
- suite.testClientDomain = "https://example.org"
- suite.testClientUserID = "test-client-user-id"
-}
-
-// SetupTest creates a postgres connection and creates the oauth_clients table before each test
-func (suite *PgClientStoreTestSuite) SetupTest() {
- suite.conn = pg.Connect(&pg.Options{})
- if err := suite.conn.Ping(context.Background()); err != nil {
- logrus.Panicf("db connection error: %s", err)
- }
- if err := suite.conn.Model(&oauthClient{}).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- }); err != nil {
- logrus.Panicf("db connection error: %s", err)
- }
-}
-
-// TearDownTest drops the oauth_clients table and closes the pg connection after each test
-func (suite *PgClientStoreTestSuite) TearDownTest() {
- if err := suite.conn.Model(&oauthClient{}).DropTable(&orm.DropTableOptions{}); err != nil {
- logrus.Panicf("drop table error: %s", err)
- }
- if err := suite.conn.Close(); err != nil {
- logrus.Panicf("error closing db connection: %s", err)
- }
- suite.conn = nil
-}
-
-func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
- // set a new client in the store
- cs := NewPGClientStore(suite.conn)
- if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
- suite.FailNow(err.Error())
- }
-
- // fetch that client from the store
- client, err := cs.GetByID(context.Background(), suite.testClientID)
- if err != nil {
- suite.FailNow(err.Error())
- }
-
- // check that the values are the same
- suite.NotNil(client)
- suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
-}
-
-func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
- // set a new client in the store
- cs := NewPGClientStore(suite.conn)
- if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
- suite.FailNow(err.Error())
- }
-
- // fetch the client from the store
- client, err := cs.GetByID(context.Background(), suite.testClientID)
- if err != nil {
- suite.FailNow(err.Error())
- }
-
- // check that the values are the same
- suite.NotNil(client)
- suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
- if err := cs.Delete(context.Background(), suite.testClientID); err != nil {
- suite.FailNow(err.Error())
- }
-
- // try to get the deleted client; we should get an error
- deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
- suite.Assert().Nil(deletedClient)
- suite.Assert().NotNil(err)
-}
-
-func TestPgClientStoreTestSuite(t *testing.T) {
- suite.Run(t, new(PgClientStoreTestSuite))
-}
diff --git a/internal/oauth/pgtokenstore.go b/internal/oauth/pgtokenstore.go
@@ -1,257 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package oauth
-
-import (
- "context"
- "errors"
- "fmt"
- "time"
-
- "github.com/go-pg/pg/v10"
- "github.com/gotosocial/oauth2/v4"
- "github.com/gotosocial/oauth2/v4/models"
- "github.com/sirupsen/logrus"
-)
-
-// pgTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend.
-type pgTokenStore struct {
- oauth2.TokenStore
- conn *pg.DB
- log *logrus.Logger
-}
-
-// NewPGTokenStore returns a token store, using postgres, that satisfies the oauth2.TokenStore interface.
-//
-// In order to allow tokens to 'expire' (not really a thing in Postgres world), it will also set off a
-// goroutine that iterates through the tokens in the DB once per minute and deletes any that have expired.
-func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth2.TokenStore {
- pts := &pgTokenStore{
- conn: conn,
- log: log,
- }
-
- // set the token store to clean out expired tokens once per minute, or return if we're done
- go func(ctx context.Context, pts *pgTokenStore, log *logrus.Logger) {
- cleanloop:
- for {
- select {
- case <-ctx.Done():
- log.Info("breaking cleanloop")
- break cleanloop
- case <-time.After(1 * time.Minute):
- log.Debug("sweeping out old oauth entries broom broom")
- if err := pts.sweep(); err != nil {
- log.Errorf("error while sweeping oauth entries: %s", err)
- }
- }
- }
- }(ctx, pts, log)
- return pts
-}
-
-// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
-func (pts *pgTokenStore) sweep() error {
- // select *all* tokens from the db
- // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
- var tokens []oauthToken
- if err := pts.conn.Model(&tokens).Select(); err != nil {
- return err
- }
-
- // iterate through and remove expired tokens
- now := time.Now()
- for _, pgt := range tokens {
- // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
- // we only want to check if a token expired before now if the expiry time is *not zero*;
- // ie., if it's been explicity set.
- if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
- if _, err := pts.conn.Model(&pgt).Delete(); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
-
-// Create creates and store the new token information.
-// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34
-func (pts *pgTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
- t, ok := info.(*models.Token)
- if !ok {
- return errors.New("info param was not a models.Token")
- }
- _, err := pts.conn.WithContext(ctx).Model(oauthTokenToPGToken(t)).Insert()
- if err != nil {
- return fmt.Errorf("error in tokenstore create: %s", err)
- }
- return nil
-}
-
-// RemoveByCode deletes a token from the DB based on the Code field
-func (pts *pgTokenStore) RemoveByCode(ctx context.Context, code string) error {
- _, err := pts.conn.Model(&oauthToken{}).Where("code = ?", code).Delete()
- if err != nil {
- return fmt.Errorf("error in tokenstore removebycode: %s", err)
- }
- return nil
-}
-
-// RemoveByAccess deletes a token from the DB based on the Access field
-func (pts *pgTokenStore) RemoveByAccess(ctx context.Context, access string) error {
- _, err := pts.conn.Model(&oauthToken{}).Where("access = ?", access).Delete()
- if err != nil {
- return fmt.Errorf("error in tokenstore removebyaccess: %s", err)
- }
- return nil
-}
-
-// RemoveByRefresh deletes a token from the DB based on the Refresh field
-func (pts *pgTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
- _, err := pts.conn.Model(&oauthToken{}).Where("refresh = ?", refresh).Delete()
- if err != nil {
- return fmt.Errorf("error in tokenstore removebyrefresh: %s", err)
- }
- return nil
-}
-
-// GetByCode selects a token from the DB based on the Code field
-func (pts *pgTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
- pgt := &oauthToken{}
- if err := pts.conn.Model(pgt).Where("code = ?", code).Select(); err != nil {
- return nil, fmt.Errorf("error in tokenstore getbycode: %s", err)
- }
- return pgTokenToOauthToken(pgt), nil
-}
-
-// GetByAccess selects a token from the DB based on the Access field
-func (pts *pgTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
- pgt := &oauthToken{}
- if err := pts.conn.Model(pgt).Where("access = ?", access).Select(); err != nil {
- return nil, fmt.Errorf("error in tokenstore getbyaccess: %s", err)
- }
- return pgTokenToOauthToken(pgt), nil
-}
-
-// GetByRefresh selects a token from the DB based on the Refresh field
-func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
- pgt := &oauthToken{}
- if err := pts.conn.Model(pgt).Where("refresh = ?", refresh).Select(); err != nil {
- return nil, fmt.Errorf("error in tokenstore getbyrefresh: %s", err)
- }
- return pgTokenToOauthToken(pgt), nil
-}
-
-/*
- The following models are basically helpers for the postgres token store implementation, they should only be used internally.
-*/
-
-// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
-//
-// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
-// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
-// then periodically sweep out tokens when that time has passed.
-//
-// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
-// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
-// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
-// and pgTokenToOauthToken can be used for that.
-type oauthToken struct {
- ClientID string
- UserID string
- RedirectURI string
- Scope string
- Code string `pg:"default:'',pk"`
- CodeChallenge string
- CodeChallengeMethod string
- CodeCreateAt time.Time `pg:"type:timestamp"`
- CodeExpiresAt time.Time `pg:"type:timestamp"`
- Access string `pg:"default:'',pk"`
- AccessCreateAt time.Time `pg:"type:timestamp"`
- AccessExpiresAt time.Time `pg:"type:timestamp"`
- Refresh string `pg:"default:'',pk"`
- RefreshCreateAt time.Time `pg:"type:timestamp"`
- RefreshExpiresAt time.Time `pg:"type:timestamp"`
-}
-
-// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
-func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
- now := time.Now()
-
- // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
- // going to cause all sorts of interesting problems. So check first to make sure that the ExpiresIn is not equal
- // to the zero value of a time.Duration, which is 0s. If it *is* empty/nil, just leave the ExpiresAt at nil as well.
-
- var cea time.Time
- if tkn.CodeExpiresIn != 0*time.Second {
- cea = now.Add(tkn.CodeExpiresIn)
- }
-
- var aea time.Time
- if tkn.AccessExpiresIn != 0*time.Second {
- aea = now.Add(tkn.AccessExpiresIn)
- }
-
- var rea time.Time
- if tkn.RefreshExpiresIn != 0*time.Second {
- rea = now.Add(tkn.RefreshExpiresIn)
- }
-
- return &oauthToken{
- ClientID: tkn.ClientID,
- UserID: tkn.UserID,
- RedirectURI: tkn.RedirectURI,
- Scope: tkn.Scope,
- Code: tkn.Code,
- CodeChallenge: tkn.CodeChallenge,
- CodeChallengeMethod: tkn.CodeChallengeMethod,
- CodeCreateAt: tkn.CodeCreateAt,
- CodeExpiresAt: cea,
- Access: tkn.Access,
- AccessCreateAt: tkn.AccessCreateAt,
- AccessExpiresAt: aea,
- Refresh: tkn.Refresh,
- RefreshCreateAt: tkn.RefreshCreateAt,
- RefreshExpiresAt: rea,
- }
-}
-
-// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
-func pgTokenToOauthToken(pgt *oauthToken) *models.Token {
- now := time.Now()
-
- return &models.Token{
- ClientID: pgt.ClientID,
- UserID: pgt.UserID,
- RedirectURI: pgt.RedirectURI,
- Scope: pgt.Scope,
- Code: pgt.Code,
- CodeChallenge: pgt.CodeChallenge,
- CodeChallengeMethod: pgt.CodeChallengeMethod,
- CodeCreateAt: pgt.CodeCreateAt,
- CodeExpiresIn: pgt.CodeExpiresAt.Sub(now),
- Access: pgt.Access,
- AccessCreateAt: pgt.AccessCreateAt,
- AccessExpiresIn: pgt.AccessExpiresAt.Sub(now),
- Refresh: pgt.Refresh,
- RefreshCreateAt: pgt.RefreshCreateAt,
- RefreshExpiresIn: pgt.RefreshExpiresAt.Sub(now),
- }
-}
diff --git a/internal/router/router.go b/internal/router/router.go
@@ -0,0 +1,120 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package router
+
+import (
+ "crypto/rand"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-contrib/sessions/memstore"
+ "github.com/gin-gonic/gin"
+ "github.com/gotosocial/gotosocial/internal/config"
+ "github.com/sirupsen/logrus"
+)
+
+// Router provides the REST interface for gotosocial, using gin.
+type Router interface {
+ // Attach a gin handler to the router with the given method and path
+ AttachHandler(method string, path string, handler gin.HandlerFunc)
+ // Attach a gin middleware to the router that will be used globally
+ AttachMiddleware(handler gin.HandlerFunc)
+ // Start the router
+ Start()
+ // Stop the router
+ Stop()
+}
+
+// router fulfils the Router interface using gin and logrus
+type router struct {
+ logger *logrus.Logger
+ engine *gin.Engine
+}
+
+// Start starts the router nicely
+func (s *router) Start() {
+ // todo: start gracefully
+ if err := s.engine.Run(); err != nil {
+ s.logger.Panicf("server error: %s", err)
+ }
+}
+
+// Stop shuts down the router nicely
+func (s *router) Stop() {
+ // todo: shut down gracefully
+}
+
+// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path.
+// If the path is set to ANY, then the handlerfunc will be used for ALL methods at its given path.
+func (s *router) AttachHandler(method string, path string, handler gin.HandlerFunc) {
+ if method == "ANY" {
+ s.engine.Any(path, handler)
+ } else {
+ s.engine.Handle(method, path, handler)
+ }
+}
+
+// AttachMiddleware attaches a gin middleware to the router that will be used globally
+func (s *router) AttachMiddleware(middleware gin.HandlerFunc) {
+ s.engine.Use(middleware)
+}
+
+// New returns a new Router with the specified configuration, using the given logrus logger.
+func New(config *config.Config, logger *logrus.Logger) (Router, error) {
+ engine := gin.New()
+
+ // create a new session store middleware
+ store, err := sessionStore()
+ if err != nil {
+ return nil, fmt.Errorf("error creating session store: %s", err)
+ }
+ engine.Use(sessions.Sessions("gotosocial-session", store))
+
+ // load html templates for use by the router
+ cwd, err := os.Getwd()
+ if err != nil {
+ return nil, fmt.Errorf("error getting current working directory: %s", err)
+ }
+ tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir))
+ logger.Debugf("loading templates from %s", tmPath)
+ engine.LoadHTMLGlob(tmPath)
+
+ return &router{
+ logger: logger,
+ engine: engine,
+ }, nil
+}
+
+// sessionStore returns a new session store with a random auth and encryption key.
+// This means that cookies using the store will be reset if gotosocial is restarted!
+func sessionStore() (memstore.Store, error) {
+ auth := make([]byte, 32)
+ crypt := make([]byte, 32)
+
+ if _, err := rand.Read(auth); err != nil {
+ return nil, err
+ }
+ if _, err := rand.Read(crypt); err != nil {
+ return nil, err
+ }
+
+ return memstore.NewStore(auth, crypt), nil
+}
diff --git a/web/template/authorize.tmpl b/web/template/authorize.tmpl
@@ -2,7 +2,7 @@
<html lang="en">
<head>
<meta charset="UTF-8" />
- <title>Auth</title>
+ <title>GoToSocial Authorization</title>
<link
rel="stylesheet"
href="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css"
@@ -11,13 +11,13 @@
<script src="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
</head>
-{{if len .code | eq 0 }}
<body>
<div class="container">
<div class="jumbotron">
<form action="/oauth/authorize" method="POST">
- <h1>Authorize</h1>
- <p>The client would like to perform actions on your behalf.</p>
+ <h1>Hi {{.user}}!</h1>
+ <p>Application <b>{{.appname}}</b> {{if len .appwebsite | eq 0 | not}}({{.appwebsite}}) {{end}}would like to perform actions on your behalf, with scope <em>{{.scope}}</em>.</p>
+ <p>The application will redirect to {{.redirect}} to continue.</p>
<p>
<button
type="submit"
@@ -31,14 +31,4 @@
</div>
</div>
</body>
-{{else}}
- <body>
- <div class="container">
- <div class="jumbotron">
- {{.code}}
- </div>
- </div>
- </body>
-{{end}}
-
</html>