diff --git a/README.md b/README.md index 86f4d0f9..16b0897b 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,56 @@ func main() { } ``` +### multiple sessions with different stores + +```go +package main + +import ( + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" +) + +func main() { + r := gin.Default() + cookieStore := cookie.NewStore([]byte("secret")) + redisStore, _ := redis.NewStore(10, "tcp", "localhost:6379", "", []byte("secret")) + sessionStores := []sessions.SessionStore{ + { + Name: "a", + Store: cookieStore, + }, + { + Name: "b", + Store: redisStore, + }, + } + r.Use(sessions.SessionsManyStores(sessionStores)) + + r.GET("/hello", func(c *gin.Context) { + sessionA := sessions.DefaultMany(c, "a") + sessionB := sessions.DefaultMany(c, "b") + + if sessionA.Get("hello") != "world!" { + sessionA.Set("hello", "world!") + sessionA.Save() + } + + if sessionB.Get("hello") != "world?" { + sessionB.Set("hello", "world?") + sessionB.Save() + } + + c.JSON(200, gin.H{ + "a": sessionA.Get("hello"), + "b": sessionB.Get("hello"), + }) + }) + r.Run(":8000") +} +``` + ## Backend Examples ### cookie-based diff --git a/cookie/cookie_test.go b/cookie/cookie_test.go index 9f2da68a..78be9e95 100644 --- a/cookie/cookie_test.go +++ b/cookie/cookie_test.go @@ -35,3 +35,7 @@ func TestCookie_SessionOptions(t *testing.T) { func TestCookie_SessionMany(t *testing.T) { tester.Many(t, newStore) } + +func TestCookie_SessionManyStores(t *testing.T) { + tester.ManyStores(t, newStore) +} diff --git a/memcached/memcached_test.go b/memcached/memcached_test.go index cd8ae1ce..56e4b354 100644 --- a/memcached/memcached_test.go +++ b/memcached/memcached_test.go @@ -41,6 +41,10 @@ func TestMemcached_SessionMany(t *testing.T) { tester.Many(t, newStore) } +func TestMemcached_SessionManyStores(t *testing.T) { + tester.ManyStores(t, newStore) +} + var newBinaryStore = func(_ *testing.T) sessions.Store { store := NewMemcacheStore( mc.NewMC(memcachedTestServer, "", ""), "", []byte("secret")) @@ -70,3 +74,7 @@ func TestBinaryMemcached_SessionOptions(t *testing.T) { func TestBinaryMemcached_SessionMany(t *testing.T) { tester.Many(t, newBinaryStore) } + +func TestBinaryMemcached_SessionManyStores(t *testing.T) { + tester.ManyStores(t, newBinaryStore) +} diff --git a/memstore/memstore_test.go b/memstore/memstore_test.go index 045eab93..e824db77 100644 --- a/memstore/memstore_test.go +++ b/memstore/memstore_test.go @@ -35,3 +35,7 @@ func TestCookie_SessionOptions(t *testing.T) { func TestCookie_SessionMany(t *testing.T) { tester.Many(t, newStore) } + +func TestCookie_SessionManyStores(t *testing.T) { + tester.ManyStores(t, newStore) +} diff --git a/mongo/mongomgo/mongomgo_test.go b/mongo/mongomgo/mongomgo_test.go index 8a05a504..183ab44c 100644 --- a/mongo/mongomgo/mongomgo_test.go +++ b/mongo/mongomgo/mongomgo_test.go @@ -43,3 +43,7 @@ func TestMongoMGO_SessionOptions(t *testing.T) { func TestMongoMGO_SessionMany(t *testing.T) { tester.Many(t, newStore) } + +func TestMongo_SessionManyStores(t *testing.T) { + tester.ManyStores(t, newStore) +} diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go index e8f39be0..093c6550 100644 --- a/postgres/postgres_test.go +++ b/postgres/postgres_test.go @@ -47,3 +47,7 @@ func TestPostgres_SessionOptions(t *testing.T) { func TestPostgres_SessionMany(t *testing.T) { tester.Many(t, newStore) } + +func TestPostgres_SessionManyStores(t *testing.T) { + tester.ManyStores(t, newStore) +} diff --git a/redis/redis_test.go b/redis/redis_test.go index 8944cbaa..453026e2 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -41,6 +41,10 @@ func TestRedis_SessionMany(t *testing.T) { tester.Many(t, newRedisStore) } +func TestRedis_SessionManyStores(t *testing.T) { + tester.ManyStores(t, newRedisStore) +} + func TestGetRedisStore(t *testing.T) { t.Run("unmatched type", func(t *testing.T) { type store struct{ Store } diff --git a/sessions.go b/sessions.go index 2d3fe775..d7914741 100644 --- a/sessions.go +++ b/sessions.go @@ -46,6 +46,12 @@ type Session interface { Save() error } +// SessionStore named session stores allow multiple sessions with different store types +type SessionStore struct { + Name string + Store Store +} + func Sessions(name string, store Store) gin.HandlerFunc { return func(c *gin.Context) { s := &session{name, c.Request, store, nil, false, c.Writer} @@ -67,6 +73,18 @@ func SessionsMany(names []string, store Store) gin.HandlerFunc { } } +func SessionsManyStores(sessionStores []SessionStore) gin.HandlerFunc { + return func(c *gin.Context) { + sessions := make(map[string]Session, len(sessionStores)) + for _, sessionStore := range sessionStores { + sessions[sessionStore.Name] = &session{sessionStore.Name, c.Request, sessionStore.Store, nil, false, c.Writer} + } + c.Set(DefaultKey, sessions) + defer context.Clear(c.Request) + c.Next() + } +} + type session struct { name string request *http.Request diff --git a/tester/tester.go b/tester/tester.go index fba6ec4e..b8fe2ee9 100644 --- a/tester/tester.go +++ b/tester/tester.go @@ -316,6 +316,57 @@ func Many(t *testing.T, newStore storeFactory) { r.ServeHTTP(res2, req2) } +func ManyStores(t *testing.T, newStore storeFactory) { + r := gin.Default() + + store := newStore(t) + sessionStores := []sessions.SessionStore{ + {Name: "a", Store: store}, + {Name: "b", Store: store}, + } + + r.Use(sessions.SessionsManyStores(sessionStores)) + + r.GET("/set", func(c *gin.Context) { + sessionA := sessions.DefaultMany(c, "a") + sessionA.Set("hello", "world") + _ = sessionA.Save() + + sessionB := sessions.DefaultMany(c, "b") + sessionB.Set("foo", "bar") + _ = sessionB.Save() + c.String(http.StatusOK, ok) + }) + + r.GET("/get", func(c *gin.Context) { + sessionA := sessions.DefaultMany(c, "a") + if sessionA.Get("hello") != "world" { + t.Error("Session writing failed") + } + _ = sessionA.Save() + + sessionB := sessions.DefaultMany(c, "b") + if sessionB.Get("foo") != "bar" { + t.Error("Session writing failed") + } + _ = sessionB.Save() + c.String(http.StatusOK, ok) + }) + + res1 := httptest.NewRecorder() + req1, _ := http.NewRequestWithContext(context.Background(), "GET", "/set", nil) + r.ServeHTTP(res1, req1) + + res2 := httptest.NewRecorder() + req2, _ := http.NewRequestWithContext(context.Background(), "GET", "/get", nil) + header := "" + for _, x := range res1.Header()["Set-Cookie"] { + header += strings.Split(x, ";")[0] + "; \n" + } + req2.Header.Set("Cookie", header) + r.ServeHTTP(res2, req2) +} + func copyCookies(req *http.Request, res *httptest.ResponseRecorder) { req.Header.Set("Cookie", strings.Join(res.Header().Values("Set-Cookie"), "; ")) }