@@ -19,6 +19,7 @@ package responsewriters
19
19
import (
20
20
"bytes"
21
21
"compress/gzip"
22
+ "context"
22
23
"encoding/hex"
23
24
"encoding/json"
24
25
"errors"
@@ -32,6 +33,7 @@ import (
32
33
"os"
33
34
"reflect"
34
35
"strconv"
36
+ "strings"
35
37
"testing"
36
38
"time"
37
39
@@ -371,6 +373,124 @@ func TestSerializeObject(t *testing.T) {
371
373
}
372
374
}
373
375
376
+ func TestDeferredResponseWriter_Write (t * testing.T ) {
377
+ smallChunk := bytes .Repeat ([]byte ("b" ), defaultGzipThresholdBytes - 1 )
378
+ largeChunk := bytes .Repeat ([]byte ("b" ), defaultGzipThresholdBytes + 1 )
379
+
380
+ tests := []struct {
381
+ name string
382
+ chunks [][]byte
383
+ expectGzip bool
384
+ }{
385
+ {
386
+ name : "one small chunk write" ,
387
+ chunks : [][]byte {smallChunk },
388
+ expectGzip : false ,
389
+ },
390
+ {
391
+ name : "two small chunk writes" ,
392
+ chunks : [][]byte {smallChunk , smallChunk },
393
+ expectGzip : false ,
394
+ },
395
+ {
396
+ name : "one large chunk writes" ,
397
+ chunks : [][]byte {largeChunk },
398
+ expectGzip : true ,
399
+ },
400
+ {
401
+ name : "two large chunk writes" ,
402
+ chunks : [][]byte {largeChunk , largeChunk },
403
+ expectGzip : true ,
404
+ },
405
+ }
406
+
407
+ for _ , tt := range tests {
408
+ t .Run (tt .name , func (t * testing.T ) {
409
+ mockResponseWriter := httptest .NewRecorder ()
410
+
411
+ drw := & deferredResponseWriter {
412
+ mediaType : "text/plain" ,
413
+ statusCode : 200 ,
414
+ contentEncoding : "gzip" ,
415
+ hw : mockResponseWriter ,
416
+ ctx : context .Background (),
417
+ }
418
+
419
+ fullPayload := []byte {}
420
+
421
+ for _ , chunk := range tt .chunks {
422
+ n , err := drw .Write (chunk )
423
+
424
+ if err != nil {
425
+ t .Fatalf ("unexpected error while writing chunk: %v" , err )
426
+ }
427
+ if n != len (chunk ) {
428
+ t .Errorf ("write is not complete, expected: %d bytes, written: %d bytes" , len (chunk ), n )
429
+ }
430
+
431
+ fullPayload = append (fullPayload , chunk ... )
432
+ }
433
+
434
+ err := drw .Close ()
435
+ if err != nil {
436
+ t .Fatalf ("unexpected error when closing deferredResponseWriter: %v" , err )
437
+ }
438
+
439
+ res := mockResponseWriter .Result ()
440
+
441
+ if res .StatusCode != http .StatusOK {
442
+ t .Fatalf ("status code is not writtend properly, expected: 200, got: %d" , res .StatusCode )
443
+ }
444
+ contentEncoding := res .Header .Get ("Content-Encoding" )
445
+ varyHeader := res .Header .Get ("Vary" )
446
+
447
+ resBytes , err := io .ReadAll (res .Body )
448
+ if err != nil {
449
+ t .Fatalf ("unexpected error occurred while reading response body: %v" , err )
450
+ }
451
+
452
+ if tt .expectGzip {
453
+ if contentEncoding != "gzip" {
454
+ t .Fatalf ("content-encoding is not set properly, expected: gzip, got: %s" , contentEncoding )
455
+ }
456
+
457
+ if ! strings .Contains (varyHeader , "Accept-Encoding" ) {
458
+ t .Errorf ("vary header doesn't have Accept-Encoding" )
459
+ }
460
+
461
+ gr , err := gzip .NewReader (bytes .NewReader (resBytes ))
462
+ if err != nil {
463
+ t .Fatalf ("failed to create gzip reader: %v" , err )
464
+ }
465
+
466
+ decompressed , err := io .ReadAll (gr )
467
+ if err != nil {
468
+ t .Fatalf ("failed to decompress: %v" , err )
469
+ }
470
+
471
+ if ! bytes .Equal (fullPayload , decompressed ) {
472
+ t .Errorf ("payload mismatch, expected: %s, got: %s" , fullPayload , decompressed )
473
+ }
474
+
475
+ } else {
476
+ if contentEncoding != "" {
477
+ t .Errorf ("content-encoding is set unexpectedly" )
478
+ }
479
+
480
+ if strings .Contains (varyHeader , "Accept-Encoding" ) {
481
+ t .Errorf ("accept encoding is set unexpectedly" )
482
+ }
483
+
484
+ if ! bytes .Equal (fullPayload , resBytes ) {
485
+ t .Errorf ("payload mismatch, expected: %s, got: %s" , fullPayload , resBytes )
486
+ }
487
+
488
+ }
489
+
490
+ })
491
+ }
492
+ }
493
+
374
494
func randTime (t * time.Time , r * rand.Rand ) {
375
495
* t = time .Unix (r .Int63n (1000 * 365 * 24 * 60 * 60 ), r .Int63 ())
376
496
}
0 commit comments