@@ -16,6 +16,7 @@ import (
1616 "github.com/gin-contrib/sessions"
1717 "github.com/gin-contrib/sessions/cookie"
1818 "github.com/gin-gonic/gin"
19+ csrf "github.com/srbry/gin-csrf"
1920 "github.com/stretchr/testify/mock"
2021 "go.uber.org/zap"
2122 "go.uber.org/zap/zapcore"
@@ -43,6 +44,10 @@ var _ = Describe("Login", func() {
4344 session sessions.Session
4445 observedLogs * observer.ObservedLogs
4546 observedZapCore zapcore.Core
47+ csrfManager = & csrf.DefaultCSRFManager {
48+ Secret : "fwibble" ,
49+ SessionName : "session" ,
50+ }
4651 )
4752
4853 BeforeEach (func () {
@@ -53,13 +58,14 @@ var _ = Describe("Login", func() {
5358 JWTCrypto : jwtCrypto ,
5459 BlaiseRestApi : mockRestApi ,
5560 Logger : observedLogger ,
61+ CSRFManager : csrfManager ,
5662 }
5763 httpRouter = gin .Default ()
5864 httpRouter .LoadHTMLGlob ("../templates/*" )
5965 store := cookie .NewStore ([]byte ("secret" ))
60- httpRouter .Use (sessions .Sessions ( "mysession" , store ))
66+ httpRouter .Use (sessions .SessionsMany ([] string { "session" , "user_session" , "session_validation" } , store ))
6167 httpRouter .POST ("/login" , func (context * gin.Context ) {
62- session = sessions .Default (context )
68+ session = sessions .DefaultMany (context , "user_session" )
6369 auth .Login (context , session )
6470 })
6571 })
@@ -374,16 +380,20 @@ var _ = Describe("Logout", func() {
374380 httpRouter * gin.Engine
375381 httpRecorder * httptest.ResponseRecorder
376382 session sessions.Session
377- auth = & authenticate.Auth {}
383+ csrfManager = & csrf.DefaultCSRFManager {
384+ Secret : "fwibble" ,
385+ SessionName : "session" ,
386+ }
387+ auth = & authenticate.Auth {CSRFManager : csrfManager }
378388 )
379389
380390 BeforeEach (func () {
381391 httpRouter = gin .Default ()
382392 httpRouter .LoadHTMLGlob ("../templates/*" )
383393 store := cookie .NewStore ([]byte ("secret" ))
384- httpRouter .Use (sessions .Sessions ( "mysession" , store ))
394+ httpRouter .Use (sessions .SessionsMany ([] string { "session" , "user_session" , "session_validation" } , store ))
385395 httpRouter .GET ("/logout" , func (context * gin.Context ) {
386- session = sessions .Default (context )
396+ session = sessions .DefaultMany (context , "user_session" )
387397 session .Set ("foobar" , "fizzbuzz" )
388398 session .Save ()
389399 Expect (session .Get ("foobar" )).ToNot (BeNil ())
@@ -412,18 +422,24 @@ var _ = Describe("AuthenticatedWithUac", func() {
412422 session sessions.Session
413423
414424 mockJwtCrypto = & mockauth.JWTCryptoInterface {}
415- auth = & authenticate.Auth {
416- JWTCrypto : mockJwtCrypto ,
425+ csrfManager = & csrf.DefaultCSRFManager {
426+ Secret : "fwibble" ,
427+ SessionName : "session" ,
428+ }
429+ auth = & authenticate.Auth {
430+ JWTCrypto : mockJwtCrypto ,
431+ CSRFManager : csrfManager ,
417432 }
418433 httpRecorder * httptest.ResponseRecorder
419434 httpRouter * gin.Engine
435+ sessionValid = false
420436 )
421437
422438 BeforeEach (func () {
423439 httpRouter = gin .Default ()
424440 httpRouter .LoadHTMLGlob ("../templates/*" )
425441 store := cookie .NewStore ([]byte ("secret" ))
426- httpRouter .Use (sessions .Sessions ( "mysession" , store ))
442+ httpRouter .Use (sessions .SessionsMany ([] string { "session" , "user_session" , "session_validation" } , store ))
427443 })
428444
429445 AfterEach (func () {
@@ -441,9 +457,13 @@ var _ = Describe("AuthenticatedWithUac", func() {
441457 Context ("when there is a token" , func () {
442458 BeforeEach (func () {
443459 httpRouter .Use (func (context * gin.Context ) {
444- session = sessions .Default (context )
460+ session = sessions .DefaultMany (context , "user_session" )
445461 session .Set (authenticate .JWT_TOKEN_KEY , "foobar" )
446462 session .Save ()
463+
464+ sessionValidation := sessions .DefaultMany (context , "session_validation" )
465+ sessionValidation .Set (authenticate .SESSION_VALID_KEY , sessionValid )
466+ sessionValidation .Save ()
447467 context .Next ()
448468 })
449469
@@ -458,10 +478,28 @@ var _ = Describe("AuthenticatedWithUac", func() {
458478 mockJwtCrypto .On ("DecryptJWT" , mock .Anything ).Return (nil , nil )
459479 })
460480
461- It ("Allows the context to continue" , func () {
462- Expect (httpRecorder .Code ).To (Equal (http .StatusOK ))
463- body := httpRecorder .Body .Bytes ()
464- Expect (string (body )).To (Equal ("true" ))
481+ Context ("and the session is valid" , func () {
482+ BeforeEach (func () {
483+ sessionValid = true
484+ })
485+
486+ It ("Allows the context to continue" , func () {
487+ Expect (httpRecorder .Code ).To (Equal (http .StatusOK ))
488+ body := httpRecorder .Body .Bytes ()
489+ Expect (string (body )).To (Equal ("true" ))
490+ })
491+ })
492+
493+ Context ("and the session is invalid" , func () {
494+ BeforeEach (func () {
495+ sessionValid = false
496+ })
497+
498+ It ("returns unauthorized" , func () {
499+ Expect (httpRecorder .Code ).To (Equal (http .StatusUnauthorized ))
500+ body := httpRecorder .Body .Bytes ()
501+ Expect (strings .Contains (string (body ), `<span class="btn__inner">Access study` )).To (BeTrue ())
502+ })
465503 })
466504 })
467505
@@ -470,7 +508,7 @@ var _ = Describe("AuthenticatedWithUac", func() {
470508 mockJwtCrypto .On ("DecryptJWT" , mock .Anything ).Return (nil , fmt .Errorf ("Explosions" ))
471509 })
472510
473- It ("return unauthorized" , func () {
511+ It ("returns unauthorized" , func () {
474512 Expect (httpRecorder .Code ).To (Equal (http .StatusUnauthorized ))
475513 body := httpRecorder .Body .Bytes ()
476514 Expect (strings .Contains (string (body ), `<span class="btn__inner">Access study` )).To (BeTrue ())
@@ -486,7 +524,7 @@ var _ = Describe("AuthenticatedWithUac", func() {
486524 })
487525 })
488526
489- It ("return unauthorized" , func () {
527+ It ("returns unauthorized" , func () {
490528 Expect (httpRecorder .Code ).To (Equal (http .StatusUnauthorized ))
491529 body := httpRecorder .Body .Bytes ()
492530 Expect (strings .Contains (string (body ), `<span class="btn__inner">Access study` )).To (BeTrue ())
@@ -512,10 +550,10 @@ var _ = Describe("Has Session", func() {
512550 httpRouter = gin .Default ()
513551 httpRouter .LoadHTMLGlob ("../templates/*" )
514552 store := cookie .NewStore ([]byte ("secret" ))
515- httpRouter .Use (sessions .Sessions ( "mysession" , store ))
553+ httpRouter .Use (sessions .SessionsMany ([] string { "session" , "user_session" , "session_validation" } , store ))
516554
517555 httpRouter .Use (func (context * gin.Context ) {
518- session = sessions .Default (context )
556+ session = sessions .DefaultMany (context , "user_session" )
519557 session .Set (authenticate .JWT_TOKEN_KEY , "foobar" )
520558 session .Save ()
521559 context .Next ()
0 commit comments