commit 1bda6a2002b18d3732b24d2496b88e137a32968b
parent 98edd75f1b0e0e63520b734c2ce20ed43794b5ef
Author: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>
Date: Sun, 8 Jan 2023 11:43:08 +0000
[bugfix] return early in websocket upgrade handler (#1315)
* launch websocket streaming in goroutine to allow upgrade handler to return
* don't send any message on ping, improved close check on failed read
* use context to signal wsconn close, ensure canceled in read goroutine
Signed-off-by: kim <grufwub@gmail.com>
Diffstat:
4 files changed, 111 insertions(+), 71 deletions(-)
diff --git a/internal/api/client.go b/internal/api/client.go
@@ -19,6 +19,8 @@
package api
import (
+ "time"
+
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
@@ -122,7 +124,7 @@ func NewClient(db db.DB, p processing.Processor) *Client {
notifications: notifications.New(p),
search: search.New(p),
statuses: statuses.New(p),
- streaming: streaming.New(p),
+ streaming: streaming.New(p, time.Second*30, 4096),
timelines: timelines.New(p),
user: user.New(p),
}
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
@@ -19,8 +19,9 @@
package streaming
import (
+ "context"
+ "errors"
"fmt"
- "net/http"
"time"
"codeberg.org/gruf/go-kv"
@@ -32,16 +33,6 @@ import (
"github.com/gorilla/websocket"
)
-var (
- wsUpgrader = websocket.Upgrader{
- ReadBufferSize: 1024,
- WriteBufferSize: 1024,
- // we expect cors requests (via eg., pinafore.social) so be lenient
- CheckOrigin: func(r *http.Request) bool { return true },
- }
- errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader)
-)
-
// StreamGETHandler swagger:operation GET /api/v1/streaming streamGet
//
// Initiate a websocket connection for live streaming of statuses and notifications.
@@ -150,21 +141,20 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
return
}
- var accessToken string
- if t := c.Query(AccessTokenQueryKey); t != "" {
- // try query param first
- accessToken = t
- } else if t := c.GetHeader(AccessTokenHeader); t != "" {
- // fall back to Sec-Websocket-Protocol
- accessToken = t
- } else {
- // no token
- err := errNoToken
- apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
- return
+ var token string
+
+ // First we check for a query param provided access token
+ if token = c.Query(AccessTokenQueryKey); token == "" {
+ // Else we check the HTTP header provided token
+ if token = c.GetHeader(AccessTokenHeader); token == "" {
+ const errStr = "no access token provided"
+ err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr)
+ apiutil.ErrorHandler(c, err, m.processor.InstanceGet)
+ return
+ }
}
- account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken)
+ account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), token)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return
@@ -178,51 +168,97 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
l := log.WithFields(kv.Fields{
{"account", account.Username},
- {"path", BasePath},
{"streamID", stream.ID},
{"streamType", streamType},
}...)
- wsConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
+ // Upgrade the incoming HTTP request, which hijacks the underlying
+ // connection and reuses it for the websocket (non-http) protocol.
+ wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil {
- // If the upgrade fails, then Upgrade replies to the client with an HTTP error response.
- // Because websocket issues are a pretty common source of headaches, we should also log
- // this at Error to make this plenty visible and help admins out a bit.
- l.Errorf("error upgrading websocket connection: %s", err)
+ l.Errorf("error upgrading websocket connection: %v", err)
close(stream.Hangup)
return
}
- defer func() {
- // cleanup
- wsConn.Close()
- close(stream.Hangup)
- }()
+ go func() {
+ // We perform the main websocket send loop in a separate
+ // goroutine in order to let the upgrade handler return.
+ // This prevents the upgrade handler from holding open any
+ // throttle / rate-limit request tokens which could become
+ // problematic on instances with multiple users.
+ l.Info("opened websocket connection")
+ defer l.Info("closed websocket connection")
+
+ // Create new context for lifetime of the connection
+ ctx, cncl := context.WithCancel(context.Background())
+
+ // Create ticker to send alive pings
+ pinger := time.NewTicker(m.dTicker)
+
+ defer func() {
+ // Signal done
+ cncl()
- streamTicker := time.NewTicker(m.tickDuration)
- defer streamTicker.Stop()
-
- // We want to stay in the loop as long as possible while the client is connected.
- // The only thing that should break the loop is if the client leaves or the connection becomes unhealthy.
- //
- // If the loop does break, we expect the client to reattempt connection, so it's cheap to leave + try again
-wsLoop:
- for {
- select {
- case m := <-stream.Messages:
- l.Trace("received message from stream")
- if err := wsConn.WriteJSON(m); err != nil {
- l.Debugf("error writing json to websocket connection; breaking off: %s", err)
- break wsLoop
+ // Close websocket conn
+ _ = wsConn.Close()
+
+ // Close processor stream
+ close(stream.Hangup)
+
+ // Stop ping ticker
+ pinger.Stop()
+ }()
+
+ go func() {
+ // Signal done
+ defer cncl()
+
+ for {
+ // We have to listen for received websocket messages in
+ // order to trigger the underlying wsConn.PingHandler().
+ //
+ // So we wait on received messages but only act on errors.
+ _, _, err := wsConn.ReadMessage()
+ if err != nil {
+ if ctx.Err() == nil {
+ // Only log error if the connection was not closed
+ // by us. Uncanceled context indicates this is the case.
+ l.Errorf("error reading from websocket: %v", err)
+ }
+ return
+ }
}
- l.Trace("wrote message into websocket connection")
- case <-streamTicker.C:
- l.Trace("received TICK from ticker")
- if err := wsConn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil {
- l.Debugf("error writing ping to websocket connection; breaking off: %s", err)
- break wsLoop
+ }()
+
+ for {
+ select {
+ // Connection closed
+ case <-ctx.Done():
+ return
+
+ // Received next stream message
+ case msg := <-stream.Messages:
+ l.Tracef("sending message to websocket: %+v", msg)
+ if err := wsConn.WriteJSON(msg); err != nil {
+ l.Errorf("error writing json to websocket: %v", err)
+ return
+ }
+
+ // Reset on each successful send.
+ pinger.Reset(m.dTicker)
+
+ // Send keep-alive "ping"
+ case <-pinger.C:
+ l.Trace("pinging websocket ...")
+ if err := wsConn.WriteMessage(
+ websocket.PingMessage,
+ []byte{},
+ ); err != nil {
+ l.Errorf("error writing ping to websocket: %v", err)
+ return
+ }
}
- l.Trace("wrote ping message into websocket connection")
}
- }
+ }()
}
diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go
@@ -23,6 +23,7 @@ import (
"time"
"github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
"github.com/superseriousbusiness/gotosocial/internal/processing"
)
@@ -41,21 +42,22 @@ const (
)
type Module struct {
- processor processing.Processor
- tickDuration time.Duration
+ processor processing.Processor
+ dTicker time.Duration
+ wsUpgrade websocket.Upgrader
}
-func New(processor processing.Processor) *Module {
+func New(processor processing.Processor, dTicker time.Duration, wsBuf int) *Module {
return &Module{
- processor: processor,
- tickDuration: 30 * time.Second,
- }
-}
+ processor: processor,
+ dTicker: dTicker,
+ wsUpgrade: websocket.Upgrader{
+ ReadBufferSize: wsBuf, // we don't expect reads
+ WriteBufferSize: wsBuf,
-func NewWithTickDuration(processor processing.Processor, tickDuration time.Duration) *Module {
- return &Module{
- processor: processor,
- tickDuration: tickDuration,
+ // we expect cors requests (via eg., pinafore.social) so be lenient
+ CheckOrigin: func(r *http.Request) bool { return true },
+ },
}
}
diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go
@@ -99,7 +99,7 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
- suite.streamingModule = streaming.NewWithTickDuration(suite.processor, 1)
+ suite.streamingModule = streaming.New(suite.processor, 1, 4096)
suite.NoError(suite.processor.Start())
}