@@ -2,10 +2,12 @@ package core
22
33import (
44 "bytes"
5+ "compress/flate"
56 "context"
67 "encoding/json"
78 "errors"
89 "fmt"
10+ "io"
911 "net"
1012 "net/http"
1113 "regexp"
@@ -17,6 +19,7 @@ import (
1719 "github.com/buger/jsonparser"
1820 "github.com/go-chi/chi/v5/middleware"
1921 "github.com/gobwas/ws"
22+ "github.com/gobwas/ws/wsflate"
2023 "github.com/gobwas/ws/wsutil"
2124 "github.com/gorilla/websocket"
2225 "github.com/tidwall/gjson"
@@ -67,6 +70,13 @@ type WebsocketMiddlewareOptions struct {
6770 ApolloCompatibilityFlags config.ApolloCompatibilityFlags
6871}
6972
73+ // NewWebsocketMiddleware creates an HTTP middleware that upgrades eligible requests to WebSocket
74+ // connections and dispatches them to an internal WebsocketHandler configured by opts.
75+ //
76+ // The returned middleware delegates non-WebSocket requests to the next handler. Options in
77+ // WebsocketMiddlewareOptions control timeouts, compression and protocol features, access control,
78+ // header/query param forwarding allow-lists, net-poll integration, and the components used to
79+ // process GraphQL operations over WebSocket.
7080func NewWebsocketMiddleware (ctx context.Context , opts WebsocketMiddlewareOptions ) func (http.Handler ) http.Handler {
7181
7282 handler := & WebsocketHandler {
@@ -87,6 +97,13 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions
8797 disableVariablesRemapping : opts .DisableVariablesRemapping ,
8898 apolloCompatibilityFlags : opts .ApolloCompatibilityFlags ,
8999 }
100+ if opts .WebSocketConfiguration != nil && opts .WebSocketConfiguration .Compression .Enabled {
101+ handler .compressionEnabled = true
102+ handler .compressionLevel = opts .WebSocketConfiguration .Compression .Level
103+ if handler .compressionLevel < 1 || handler .compressionLevel > 9 {
104+ handler .compressionLevel = flate .DefaultCompression
105+ }
106+ }
90107 if opts .WebSocketConfiguration != nil && opts .WebSocketConfiguration .AbsintheProtocol .Enabled {
91108 handler .absintheHandlerEnabled = true
92109 handler .absintheHandlerPath = opts .WebSocketConfiguration .AbsintheProtocol .HandlerPath
@@ -156,13 +173,20 @@ type wsConnectionWrapper struct {
156173 mu sync.Mutex
157174 readTimeout time.Duration
158175 writeTimeout time.Duration
176+
177+ // Compression fields
178+ compressionEnabled bool
179+ compressionLevel int
159180}
160181
161- func newWSConnectionWrapper (conn net.Conn , readTimeout , writeTimeout time.Duration ) * wsConnectionWrapper {
182+ // deflate compression level when enabled (typically 1–9 or flate.DefaultCompression).
183+ func newWSConnectionWrapper (conn net.Conn , readTimeout , writeTimeout time.Duration , compressionEnabled bool , compressionLevel int ) * wsConnectionWrapper {
162184 return & wsConnectionWrapper {
163- conn : conn ,
164- readTimeout : readTimeout ,
165- writeTimeout : writeTimeout ,
185+ conn : conn ,
186+ readTimeout : readTimeout ,
187+ writeTimeout : writeTimeout ,
188+ compressionEnabled : compressionEnabled ,
189+ compressionLevel : compressionLevel ,
166190 }
167191}
168192
@@ -175,9 +199,51 @@ func (c *wsConnectionWrapper) ReadJSON(v any) error {
175199 }
176200 }
177201
178- text , err := wsutil .ReadClientText (c .conn )
179- if err != nil {
180- return err
202+ var text []byte
203+ var err error
204+
205+ if c .compressionEnabled {
206+ // Read frames directly and handle compression
207+ controlHandler := wsutil .ControlFrameHandler (c .conn , ws .StateServerSide )
208+ for {
209+ frame , err := ws .ReadFrame (c .conn )
210+ if err != nil {
211+ return err
212+ }
213+
214+ // Unmask client frames
215+ if frame .Header .Masked {
216+ ws .Cipher (frame .Payload , frame .Header .Mask , 0 )
217+ }
218+
219+ if frame .Header .OpCode .IsControl () {
220+ if err := controlHandler (frame .Header , bytes .NewReader (frame .Payload )); err != nil {
221+ return err
222+ }
223+ continue
224+ }
225+
226+ if frame .Header .OpCode == ws .OpText || frame .Header .OpCode == ws .OpBinary {
227+ // Check if frame is compressed (RSV1 bit set)
228+ isCompressed , err := wsflate .IsCompressed (frame .Header )
229+ if err != nil {
230+ return err
231+ }
232+ if isCompressed {
233+ frame , err = wsflate .DecompressFrame (frame )
234+ if err != nil {
235+ return err
236+ }
237+ }
238+ text = frame .Payload
239+ break
240+ }
241+ }
242+ } else {
243+ text , err = wsutil .ReadClientText (c .conn )
244+ if err != nil {
245+ return err
246+ }
181247 }
182248
183249 return json .Unmarshal (text , v )
@@ -195,6 +261,10 @@ func (c *wsConnectionWrapper) WriteText(text string) error {
195261 }
196262 }
197263
264+ if c .compressionEnabled {
265+ return c .writeCompressed ([]byte (text ))
266+ }
267+
198268 return wsutil .WriteServerText (c .conn , []byte (text ))
199269}
200270
@@ -213,9 +283,32 @@ func (c *wsConnectionWrapper) WriteJSON(v any) error {
213283 }
214284 }
215285
286+ if c .compressionEnabled {
287+ return c .writeCompressed (data )
288+ }
289+
216290 return wsutil .WriteServerText (c .conn , data )
217291}
218292
293+ // writeCompressed writes data with compression. Must be called with the mutex held.
294+ func (c * wsConnectionWrapper ) writeCompressed (data []byte ) error {
295+ var buf bytes.Buffer
296+ writer := wsflate .NewWriter (& buf , func (w io.Writer ) wsflate.Compressor {
297+ fw , _ := flate .NewWriter (w , c .compressionLevel )
298+ return fw
299+ })
300+ if _ , err := writer .Write (data ); err != nil {
301+ return err
302+ }
303+ if err := writer .Flush (); err != nil {
304+ return err
305+ }
306+
307+ frame := ws .NewFrame (ws .OpText , true , buf .Bytes ())
308+ frame .Header .Rsv = ws .Rsv (true , false , false ) // Set RSV1 bit for compression
309+ return ws .WriteFrame (c .conn , frame )
310+ }
311+
219312func (c * wsConnectionWrapper ) WriteCloseFrame (code ws.StatusCode , reason string ) error {
220313 c .mu .Lock ()
221314 defer c .mu .Unlock ()
@@ -267,6 +360,9 @@ type WebsocketHandler struct {
267360 disableVariablesRemapping bool
268361
269362 apolloCompatibilityFlags config.ApolloCompatibilityFlags
363+
364+ compressionEnabled bool
365+ compressionLevel int
270366}
271367
272368func (h * WebsocketHandler ) handleUpgradeRequest (w http.ResponseWriter , r * http.Request ) {
@@ -309,7 +405,29 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
309405 return false
310406 },
311407 }
408+
409+ // Configure permessage-deflate compression if enabled
410+ var compressionNegotiated bool
411+ var ext wsflate.Extension
412+ if h .compressionEnabled {
413+ ext = wsflate.Extension {
414+ Parameters : wsflate.Parameters {
415+ ServerNoContextTakeover : true ,
416+ ClientNoContextTakeover : true ,
417+ },
418+ }
419+ upgrader .Negotiate = ext .Negotiate
420+ }
421+
312422 c , _ , _ , err := upgrader .Upgrade (r , w )
423+
424+ // Check if compression was negotiated
425+ if h .compressionEnabled && err == nil {
426+ if _ , accepted := ext .Accepted (); accepted {
427+ compressionNegotiated = true
428+ }
429+ }
430+
313431 if err != nil {
314432 requestLogger .Warn ("Websocket upgrade" , zap .Error (err ))
315433 _ = c .Close ()
@@ -325,7 +443,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
325443 // After successful upgrade, we can't write to the response writer anymore
326444 // because it's hijacked by the websocket connection
327445
328- conn := newWSConnectionWrapper (c , h .readTimeout , h .writeTimeout )
446+ conn := newWSConnectionWrapper (c , h .readTimeout , h .writeTimeout , compressionNegotiated , h . compressionLevel )
329447 protocol , err := wsproto .NewProtocol (subProtocol , conn )
330448 if err != nil {
331449 requestLogger .Error ("Create websocket protocol" , zap .Error (err ))
@@ -1282,4 +1400,4 @@ func (h *WebSocketConnectionHandler) Close(unsubscribe bool) {
12821400 if err != nil {
12831401 h .logger .Debug ("Closing websocket connection" , zap .Error (err ))
12841402 }
1285- }
1403+ }
0 commit comments