Skip to content

Commit 1a6fdac

Browse files
📝 Add docstrings to per-message-deflate
Docstrings generation was requested by @dbinnersley. * #2424 (comment) The following files were modified: * `router/core/websocket.go`
1 parent 2173504 commit 1a6fdac

File tree

1 file changed

+127
-9
lines changed

1 file changed

+127
-9
lines changed

router/core/websocket.go

Lines changed: 127 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package core
22

33
import (
44
"bytes"
5+
"compress/flate"
56
"context"
67
"encoding/json"
78
"errors"
89
"fmt"
10+
"io"
911
"net"
1012
"net/http"
1113
"regexp"
@@ -17,6 +19,7 @@ import (
1719
"github.com/buger/jsonparser"
1820
"github.com/go-chi/chi/v5/middleware"
1921
"github.com/gobwas/ws"
22+
"github.com/gobwas/ws/wsflate"
2023
"github.com/gobwas/ws/wsutil"
2124
"github.com/gorilla/websocket"
2225
"github.com/tidwall/gjson"
@@ -67,6 +70,13 @@ type WebsocketMiddlewareOptions struct {
6770
ApolloCompatibilityFlags config.ApolloCompatibilityFlags
6871
}
6972

73+
// NewWebsocketMiddleware creates an HTTP middleware that upgrades eligible requests to WebSocket
74+
// connections and dispatches them to an internal WebsocketHandler configured by opts.
75+
//
76+
// The returned middleware delegates non-WebSocket requests to the next handler. Options in
77+
// WebsocketMiddlewareOptions control timeouts, compression and protocol features, access control,
78+
// header/query param forwarding allow-lists, net-poll integration, and the components used to
79+
// process GraphQL operations over WebSocket.
7080
func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions) func(http.Handler) http.Handler {
7181

7282
handler := &WebsocketHandler{
@@ -87,6 +97,13 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions
8797
disableVariablesRemapping: opts.DisableVariablesRemapping,
8898
apolloCompatibilityFlags: opts.ApolloCompatibilityFlags,
8999
}
100+
if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.Compression.Enabled {
101+
handler.compressionEnabled = true
102+
handler.compressionLevel = opts.WebSocketConfiguration.Compression.Level
103+
if handler.compressionLevel < 1 || handler.compressionLevel > 9 {
104+
handler.compressionLevel = flate.DefaultCompression
105+
}
106+
}
90107
if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.AbsintheProtocol.Enabled {
91108
handler.absintheHandlerEnabled = true
92109
handler.absintheHandlerPath = opts.WebSocketConfiguration.AbsintheProtocol.HandlerPath
@@ -156,13 +173,20 @@ type wsConnectionWrapper struct {
156173
mu sync.Mutex
157174
readTimeout time.Duration
158175
writeTimeout time.Duration
176+
177+
// Compression fields
178+
compressionEnabled bool
179+
compressionLevel int
159180
}
160181

161-
func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration) *wsConnectionWrapper {
182+
// deflate compression level when enabled (typically 1–9 or flate.DefaultCompression).
183+
func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration, compressionEnabled bool, compressionLevel int) *wsConnectionWrapper {
162184
return &wsConnectionWrapper{
163-
conn: conn,
164-
readTimeout: readTimeout,
165-
writeTimeout: writeTimeout,
185+
conn: conn,
186+
readTimeout: readTimeout,
187+
writeTimeout: writeTimeout,
188+
compressionEnabled: compressionEnabled,
189+
compressionLevel: compressionLevel,
166190
}
167191
}
168192

@@ -175,9 +199,51 @@ func (c *wsConnectionWrapper) ReadJSON(v any) error {
175199
}
176200
}
177201

178-
text, err := wsutil.ReadClientText(c.conn)
179-
if err != nil {
180-
return err
202+
var text []byte
203+
var err error
204+
205+
if c.compressionEnabled {
206+
// Read frames directly and handle compression
207+
controlHandler := wsutil.ControlFrameHandler(c.conn, ws.StateServerSide)
208+
for {
209+
frame, err := ws.ReadFrame(c.conn)
210+
if err != nil {
211+
return err
212+
}
213+
214+
// Unmask client frames
215+
if frame.Header.Masked {
216+
ws.Cipher(frame.Payload, frame.Header.Mask, 0)
217+
}
218+
219+
if frame.Header.OpCode.IsControl() {
220+
if err := controlHandler(frame.Header, bytes.NewReader(frame.Payload)); err != nil {
221+
return err
222+
}
223+
continue
224+
}
225+
226+
if frame.Header.OpCode == ws.OpText || frame.Header.OpCode == ws.OpBinary {
227+
// Check if frame is compressed (RSV1 bit set)
228+
isCompressed, err := wsflate.IsCompressed(frame.Header)
229+
if err != nil {
230+
return err
231+
}
232+
if isCompressed {
233+
frame, err = wsflate.DecompressFrame(frame)
234+
if err != nil {
235+
return err
236+
}
237+
}
238+
text = frame.Payload
239+
break
240+
}
241+
}
242+
} else {
243+
text, err = wsutil.ReadClientText(c.conn)
244+
if err != nil {
245+
return err
246+
}
181247
}
182248

183249
return json.Unmarshal(text, v)
@@ -195,6 +261,10 @@ func (c *wsConnectionWrapper) WriteText(text string) error {
195261
}
196262
}
197263

264+
if c.compressionEnabled {
265+
return c.writeCompressed([]byte(text))
266+
}
267+
198268
return wsutil.WriteServerText(c.conn, []byte(text))
199269
}
200270

@@ -213,9 +283,32 @@ func (c *wsConnectionWrapper) WriteJSON(v any) error {
213283
}
214284
}
215285

286+
if c.compressionEnabled {
287+
return c.writeCompressed(data)
288+
}
289+
216290
return wsutil.WriteServerText(c.conn, data)
217291
}
218292

293+
// writeCompressed writes data with compression. Must be called with the mutex held.
294+
func (c *wsConnectionWrapper) writeCompressed(data []byte) error {
295+
var buf bytes.Buffer
296+
writer := wsflate.NewWriter(&buf, func(w io.Writer) wsflate.Compressor {
297+
fw, _ := flate.NewWriter(w, c.compressionLevel)
298+
return fw
299+
})
300+
if _, err := writer.Write(data); err != nil {
301+
return err
302+
}
303+
if err := writer.Flush(); err != nil {
304+
return err
305+
}
306+
307+
frame := ws.NewFrame(ws.OpText, true, buf.Bytes())
308+
frame.Header.Rsv = ws.Rsv(true, false, false) // Set RSV1 bit for compression
309+
return ws.WriteFrame(c.conn, frame)
310+
}
311+
219312
func (c *wsConnectionWrapper) WriteCloseFrame(code ws.StatusCode, reason string) error {
220313
c.mu.Lock()
221314
defer c.mu.Unlock()
@@ -267,6 +360,9 @@ type WebsocketHandler struct {
267360
disableVariablesRemapping bool
268361

269362
apolloCompatibilityFlags config.ApolloCompatibilityFlags
363+
364+
compressionEnabled bool
365+
compressionLevel int
270366
}
271367

272368
func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) {
@@ -309,7 +405,29 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
309405
return false
310406
},
311407
}
408+
409+
// Configure permessage-deflate compression if enabled
410+
var compressionNegotiated bool
411+
var ext wsflate.Extension
412+
if h.compressionEnabled {
413+
ext = wsflate.Extension{
414+
Parameters: wsflate.Parameters{
415+
ServerNoContextTakeover: true,
416+
ClientNoContextTakeover: true,
417+
},
418+
}
419+
upgrader.Negotiate = ext.Negotiate
420+
}
421+
312422
c, _, _, err := upgrader.Upgrade(r, w)
423+
424+
// Check if compression was negotiated
425+
if h.compressionEnabled && err == nil {
426+
if _, accepted := ext.Accepted(); accepted {
427+
compressionNegotiated = true
428+
}
429+
}
430+
313431
if err != nil {
314432
requestLogger.Warn("Websocket upgrade", zap.Error(err))
315433
_ = c.Close()
@@ -325,7 +443,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
325443
// After successful upgrade, we can't write to the response writer anymore
326444
// because it's hijacked by the websocket connection
327445

328-
conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout)
446+
conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout, compressionNegotiated, h.compressionLevel)
329447
protocol, err := wsproto.NewProtocol(subProtocol, conn)
330448
if err != nil {
331449
requestLogger.Error("Create websocket protocol", zap.Error(err))
@@ -1282,4 +1400,4 @@ func (h *WebSocketConnectionHandler) Close(unsubscribe bool) {
12821400
if err != nil {
12831401
h.logger.Debug("Closing websocket connection", zap.Error(err))
12841402
}
1285-
}
1403+
}

0 commit comments

Comments
 (0)