commit e323a930bff2c3a7d6b591e0bdcd092a5ed60f18
parent cb2f84e551727bd1852ed5fd93777289d3439bbf
Author: darrinsmart <darrin@djs.to>
Date: Sat, 11 Mar 2023 02:10:58 -0800
[feature] Support multiple subscriptions on single websocket connection (#1489)
- Allow Oauth authentication on websocket endpoint
- Make streamType query parameter optional
- Read websocket commands from client and update subscriptions
Diffstat:
4 files changed, 74 insertions(+), 26 deletions(-)
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
@@ -20,14 +20,16 @@ package streaming
import (
"context"
- "errors"
- "fmt"
"time"
"codeberg.org/gruf/go-kv"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+ streampkg "github.com/superseriousbusiness/gotosocial/internal/stream"
+ "golang.org/x/exp/slices"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -134,32 +136,37 @@ import (
// '400':
// description: bad request
func (m *Module) StreamGETHandler(c *gin.Context) {
- streamType := c.Query(StreamQueryKey)
- if streamType == "" {
- err := fmt.Errorf("no stream type provided under query key %s", StreamQueryKey)
- apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
- return
- }
-
- var token string
// First we check for a query param provided access token
- if token = c.Query(AccessTokenQueryKey); token == "" {
+ token := c.Query(AccessTokenQueryKey)
+ if token == "" {
// Else we check the HTTP header provided token
- if token = c.GetHeader(AccessTokenHeader); token == "" {
- const errStr = "no access token provided"
- err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr)
+ token = c.GetHeader(AccessTokenHeader)
+ }
+
+ var account *gtsmodel.Account
+ if token != "" {
+ // Check the explicit token
+ var errWithCode gtserror.WithCode
+ account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
+ }
+ } else {
+ // If no explicit token was provided, try regular oauth
+ auth, errStr := oauth.Authed(c, true, true, true, true)
+ if errStr != nil {
+ err := gtserror.NewErrorUnauthorized(errStr, errStr.Error())
apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1)
return
}
+ account = auth.Account
}
- account, errWithCode := m.processor.Stream().Authorize(c.Request.Context(), token)
- if errWithCode != nil {
- apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
- return
- }
-
+ // Get the initial stream type, if there is one.
+ // streamType will be an empty string if one wasn't supplied. Open() will deal with this
+ streamType := c.Query(StreamQueryKey)
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
@@ -219,8 +226,9 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// We have to listen for received websocket messages in
// order to trigger the underlying wsConn.PingHandler().
//
- // So we wait on received messages but only act on errors.
- _, _, err := wsConn.ReadMessage()
+ // Read JSON objects from the client and act on them
+ var msg map[string]string
+ err := wsConn.ReadJSON(&msg)
if err != nil {
if ctx.Err() == nil {
// Only log error if the connection was not closed
@@ -229,6 +237,33 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
return
}
+ l.Tracef("received message from websocket: %v", msg)
+
+ // If the message contains 'stream' and 'type' fields, we can
+ // update the set of timelines that are subscribed for events.
+ // everything else is ignored.
+ action := msg["type"]
+ streamType := msg["stream"]
+
+ // Ignore if the streamType is unknown (or missing), so a bad
+ // client can't cause extra memory allocations
+ if !slices.Contains(streampkg.AllStatusTimelines, streamType) {
+ l.Warnf("Unknown 'stream' field: %v", msg)
+ continue
+ }
+
+ switch action {
+ case "subscribe":
+ stream.Lock()
+ stream.Timelines[streamType] = true
+ stream.Unlock()
+ case "unsubscribe":
+ stream.Lock()
+ delete(stream.Timelines, streamType)
+ stream.Unlock()
+ default:
+ l.Warnf("Invalid 'type' field: %v", msg)
+ }
}
}()
diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go
@@ -45,9 +45,17 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err))
}
+ // Each stream can be subscibed to multiple timelines.
+ // Record them in a set, and include the initial one
+ // if it was given to us
+ timelines := map[string]bool{}
+ if streamTimeline != "" {
+ timelines[streamTimeline] = true
+ }
+
thisStream := &stream.Stream{
ID: streamID,
- Timeline: streamTimeline,
+ Timelines: timelines,
Messages: make(chan *stream.Message, 100),
Hangup: make(chan interface{}, 1),
Connected: true,
diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go
@@ -63,12 +63,15 @@ func (p *Processor) toAccount(payload string, event string, timelines []string,
}
for _, t := range timelines {
- if s.Timeline == string(t) {
+ if _, found := s.Timelines[t]; found {
s.Messages <- &stream.Message{
Stream: []string{string(t)},
Event: string(event),
Payload: payload,
}
+ // break out to the outer loop, to avoid sending duplicates
+ // of the same event to the same stream
+ break
}
}
}
diff --git a/internal/stream/stream.go b/internal/stream/stream.go
@@ -63,8 +63,10 @@ type StreamsForAccount struct {
type Stream struct {
// ID of this stream, generated during creation.
ID string
- // Timeline of this stream: user/public/etc
- Timeline string
+ // A set of timelines of this stream: user/public/etc
+ // a matching key means the timeline is subscribed. The value
+ // is ignored
+ Timelines map[string]bool
// Channel of messages for the client to read from
Messages chan *Message
// Channel to close when the client drops away