@@ -2,22 +2,42 @@ package csrf
22
33import (
44 "errors"
5+ "reflect"
56 "time"
67
78 "github.com/gofiber/fiber/v2"
89)
910
10- var errTokenNotFound = errors .New ("csrf token not found" )
11+ var (
12+ ErrTokenNotFound = errors .New ("csrf token not found" )
13+ ErrTokenInvalid = errors .New ("csrf token invalid" )
14+ ErrNoReferer = errors .New ("referer not supplied" )
15+ ErrBadReferer = errors .New ("referer invalid" )
16+ dummyValue = []byte {'+' }
17+ )
18+
19+ type CSRFHandler struct {
20+ config * Config
21+ sessionManager * sessionManager
22+ storageManager * storageManager
23+ }
1124
1225// New creates a new middleware handler
1326func New (config ... Config ) fiber.Handler {
1427 // Set default config
1528 cfg := configDefault (config ... )
1629
17- // Create manager to simplify storage operations ( see manager.go )
18- manager := newManager (cfg .Storage )
30+ // Create manager to simplify storage operations ( see *_manager.go )
31+ var sessionManager * sessionManager
32+ var storageManager * storageManager
33+ if cfg .Session != nil {
34+ // Register the Token struct in the session store
35+ cfg .Session .RegisterType (Token {})
1936
20- dummyValue := []byte {'+' }
37+ sessionManager = newSessionManager (cfg .Session , cfg .SessionKey )
38+ } else {
39+ storageManager = newStorageManager (cfg .Storage )
40+ }
2141
2242 // Return new handler
2343 return func (c * fiber.Ctx ) error {
@@ -26,36 +46,69 @@ func New(config ...Config) fiber.Handler {
2646 return c .Next ()
2747 }
2848
49+ // Store the CSRF handler in the context if a context key is specified
50+ if cfg .HandlerContextKey != "" {
51+ c .Locals (cfg .HandlerContextKey , & CSRFHandler {
52+ config : & cfg ,
53+ sessionManager : sessionManager ,
54+ storageManager : storageManager ,
55+ })
56+ }
57+
2958 var token string
3059
3160 // Action depends on the HTTP method
3261 switch c .Method () {
3362 case fiber .MethodGet , fiber .MethodHead , fiber .MethodOptions , fiber .MethodTrace :
34- // Declare empty token and try to get existing CSRF from cookie
35- token = c .Cookies (cfg .CookieName )
63+ cookieToken := c .Cookies (cfg .CookieName )
64+
65+ if cookieToken != "" {
66+ rawToken := getTokenFromStorage (c , cookieToken , cfg , sessionManager , storageManager )
67+
68+ if rawToken != nil {
69+ token = string (rawToken )
70+ }
71+ }
3672 default :
3773 // Assume that anything not defined as 'safe' by RFC7231 needs protection
3874
75+ // Enforce an origin check for HTTPS connections.
76+ if c .Protocol () == "https" {
77+ if err := refererMatchesHost (c ); err != nil {
78+ return cfg .ErrorHandler (c , err )
79+ }
80+ }
81+
3982 // Extract token from client request i.e. header, query, param, form or cookie
40- token , err := cfg .Extractor (c )
83+ extractedToken , err := cfg .Extractor (c )
4184 if err != nil {
4285 return cfg .ErrorHandler (c , err )
4386 }
4487
45- // if token does not exist in Storage
46- if manager .getRaw (token ) == nil {
47- // Expire cookie
48- c .Cookie (& fiber.Cookie {
49- Name : cfg .CookieName ,
50- Domain : cfg .CookieDomain ,
51- Path : cfg .CookiePath ,
52- Expires : time .Now ().Add (- 1 * time .Minute ),
53- Secure : cfg .CookieSecure ,
54- HTTPOnly : cfg .CookieHTTPOnly ,
55- SameSite : cfg .CookieSameSite ,
56- SessionOnly : cfg .CookieSessionOnly ,
57- })
58- return cfg .ErrorHandler (c , errTokenNotFound )
88+ if extractedToken == "" {
89+ return cfg .ErrorHandler (c , ErrTokenNotFound )
90+ }
91+
92+ // If not using CsrfFromCookie extractor, check that the token matches the cookie
93+ // This is to prevent CSRF attacks by using a Double Submit Cookie method
94+ // Useful when we do not have access to the users Session
95+ if ! isCsrfFromCookie (cfg .Extractor ) && extractedToken != c .Cookies (cfg .CookieName ) {
96+ return cfg .ErrorHandler (c , ErrTokenInvalid )
97+ }
98+
99+ rawToken := getTokenFromStorage (c , extractedToken , cfg , sessionManager , storageManager )
100+
101+ if rawToken == nil {
102+ // If token is not in storage, expire the cookie
103+ expireCSRFCookie (c , cfg )
104+ // and return an error
105+ return cfg .ErrorHandler (c , ErrTokenNotFound )
106+ }
107+ if cfg .SingleUseToken {
108+ // If token is single use, delete it from storage
109+ deleteTokenFromStorage (c , extractedToken , cfg , sessionManager , storageManager )
110+ } else {
111+ token = string (rawToken )
59112 }
60113 }
61114
@@ -65,29 +118,16 @@ func New(config ...Config) fiber.Handler {
65118 token = cfg .KeyGenerator ()
66119 }
67120
68- // Add/update token to Storage
69- manager .setRaw (token , dummyValue , cfg .Expiration )
70-
71- // Create cookie to pass token to client
72- cookie := & fiber.Cookie {
73- Name : cfg .CookieName ,
74- Value : token ,
75- Domain : cfg .CookieDomain ,
76- Path : cfg .CookiePath ,
77- Expires : time .Now ().Add (cfg .Expiration ),
78- Secure : cfg .CookieSecure ,
79- HTTPOnly : cfg .CookieHTTPOnly ,
80- SameSite : cfg .CookieSameSite ,
81- SessionOnly : cfg .CookieSessionOnly ,
82- }
83- // Set cookie to response
84- c .Cookie (cookie )
121+ // Create or extend the token in the storage
122+ createOrExtendTokenInStorage (c , token , cfg , sessionManager , storageManager )
85123
86- // Protect clients from caching the response by telling the browser
87- // a new header value is generated
124+ // Update the CSRF cookie
125+ updateCSRFCookie (c , cfg , token )
126+
127+ // Tell the browser that a new header value is generated
88128 c .Vary (fiber .HeaderCookie )
89129
90- // Store token in context if set
130+ // Store the token in the context if a context key is specified
91131 if cfg .ContextKey != "" {
92132 c .Locals (cfg .ContextKey , token )
93133 }
@@ -96,3 +136,95 @@ func New(config ...Config) fiber.Handler {
96136 return c .Next ()
97137 }
98138}
139+
140+ // getTokenFromStorage returns the raw token from the storage
141+ // returns nil if the token does not exist, is expired or is invalid
142+ func getTokenFromStorage (c * fiber.Ctx , token string , cfg Config , sessionManager * sessionManager , storageManager * storageManager ) []byte {
143+ if cfg .Session != nil {
144+ return sessionManager .getRaw (c , token , dummyValue )
145+ }
146+ return storageManager .getRaw (token )
147+ }
148+
149+ // createOrExtendTokenInStorage creates or extends the token in the storage
150+ func createOrExtendTokenInStorage (c * fiber.Ctx , token string , cfg Config , sessionManager * sessionManager , storageManager * storageManager ) {
151+ if cfg .Session != nil {
152+ sessionManager .setRaw (c , token , dummyValue , cfg .Expiration )
153+ } else {
154+ storageManager .setRaw (token , dummyValue , cfg .Expiration )
155+ }
156+ }
157+
158+ func deleteTokenFromStorage (c * fiber.Ctx , token string , cfg Config , sessionManager * sessionManager , storageManager * storageManager ) {
159+ if cfg .Session != nil {
160+ sessionManager .delRaw (c )
161+ } else {
162+ storageManager .delRaw (token )
163+ }
164+ }
165+
166+ // Update CSRF cookie
167+ // if expireCookie is true, the cookie will expire immediately
168+ func updateCSRFCookie (c * fiber.Ctx , cfg Config , token string ) {
169+ setCSRFCookie (c , cfg , token , cfg .Expiration )
170+ }
171+
172+ func expireCSRFCookie (c * fiber.Ctx , cfg Config ) {
173+ setCSRFCookie (c , cfg , "" , - time .Hour )
174+ }
175+
176+ func setCSRFCookie (c * fiber.Ctx , cfg Config , token string , expiry time.Duration ) {
177+ cookie := & fiber.Cookie {
178+ Name : cfg .CookieName ,
179+ Value : token ,
180+ Domain : cfg .CookieDomain ,
181+ Path : cfg .CookiePath ,
182+ Secure : cfg .CookieSecure ,
183+ HTTPOnly : cfg .CookieHTTPOnly ,
184+ SameSite : cfg .CookieSameSite ,
185+ SessionOnly : cfg .CookieSessionOnly ,
186+ Expires : time .Now ().Add (expiry ),
187+ }
188+
189+ // Set the CSRF cookie to the response
190+ c .Cookie (cookie )
191+ }
192+
193+ // DeleteToken removes the token found in the context from the storage
194+ // and expires the CSRF cookie
195+ func (handler * CSRFHandler ) DeleteToken (c * fiber.Ctx ) error {
196+ // Get the config from the context
197+ config := handler .config
198+ if config == nil {
199+ panic ("CSRFHandler config not found in context" )
200+ }
201+ // Extract token from the client request cookie
202+ cookieToken := c .Cookies (config .CookieName )
203+ if cookieToken == "" {
204+ return config .ErrorHandler (c , ErrTokenNotFound )
205+ }
206+ // Remove the token from storage
207+ deleteTokenFromStorage (c , cookieToken , * config , handler .sessionManager , handler .storageManager )
208+ // Expire the cookie
209+ expireCSRFCookie (c , * config )
210+ return nil
211+ }
212+
213+ // isCsrfFromCookie checks if the extractor is set to ExtractFromCookie
214+ func isCsrfFromCookie (extractor interface {}) bool {
215+ return reflect .ValueOf (extractor ).Pointer () == reflect .ValueOf (CsrfFromCookie ).Pointer ()
216+ }
217+
218+ // refererMatchesHost checks that the referer header matches the host header
219+ // returns an error if the referer header is not present or is invalid
220+ // returns nil if the referer header is valid
221+ func refererMatchesHost (c * fiber.Ctx ) error {
222+ referer := c .Get (fiber .HeaderReferer )
223+ if referer == "" {
224+ return ErrNoReferer
225+ }
226+ if referer != c .Protocol ()+ "://" + c .Hostname () {
227+ return ErrBadReferer
228+ }
229+ return nil
230+ }
0 commit comments