Skip to content
This repository was archived by the owner on Dec 23, 2024. It is now read-only.

Commit 759e085

Browse files
authored
Merge pull request #24 from InVisionApp/qparamtoken
query param access token
2 parents e4ee692 + 139d65e commit 759e085

2 files changed

Lines changed: 177 additions & 18 deletions

File tree

middleware_accesstoken.go

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ import (
77
)
88

99
type accessTokens struct {
10-
headerName string
11-
tokens []string
10+
paramName string
11+
tokens []string
12+
getFunc func(string, *http.Request) string
13+
missingMessage string
1214
}
1315

1416
/*
15-
NewMiddlewareAccessToken creates a new handler to verify access tokens in a rye chain.
17+
NewMiddlewareAccessToken creates a new handler to verify access tokens passed as a header.
1618
1719
Example usage:
1820
@@ -23,19 +25,60 @@ Example usage:
2325
})).Methods("POST")
2426
*/
2527
func NewMiddlewareAccessToken(headerName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response {
28+
return newAccessTokenHandler(headerName, tokens, "header")
29+
}
30+
31+
/*
32+
NewMiddlewareAccessQueryToken creates a new handler to verify access tokens passed as a query parameter.
33+
34+
Example usage:
35+
36+
routes.Handle("/some/route", a.Dependencies.MWHandler.Handle(
37+
[]rye.Handler{
38+
rye.NewMiddlewareAccessQueryToken(queryParamName, []string{token1, token2}),
39+
yourHandler,
40+
})).Methods("POST")
41+
*/
42+
func NewMiddlewareAccessQueryToken(queryParamName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response {
43+
return newAccessTokenHandler(queryParamName, tokens, "query")
44+
}
45+
46+
func newAccessTokenHandler(name string, tokens []string, tokenType string) func(rw http.ResponseWriter, req *http.Request) *Response {
2647
a := &accessTokens{
27-
headerName: headerName,
28-
tokens: tokens,
48+
paramName: name,
49+
tokens: tokens,
50+
}
51+
52+
switch tokenType {
53+
54+
case "query":
55+
a.getFunc = func(s string, r *http.Request) string {
56+
q, ok := r.URL.Query()[s]
57+
if !ok {
58+
return ""
59+
}
60+
61+
return q[0]
62+
}
63+
a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name)
64+
65+
default:
66+
// default to using the header
67+
a.getFunc = func(s string, r *http.Request) string {
68+
return r.Header.Get(s)
69+
}
70+
a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name)
2971
}
72+
3073
return a.handle
3174
}
3275

3376
func (a *accessTokens) handle(rw http.ResponseWriter, r *http.Request) *Response {
34-
token := r.Header.Get(a.headerName)
77+
token := a.getFunc(a.paramName, r)
3578

3679
if token == "" {
3780
return &Response{
38-
Err: fmt.Errorf("No access token found; ensure you pass '%s' in header", a.headerName),
81+
Err: errors.New(a.missingMessage),
3982
StatusCode: http.StatusUnauthorized,
4083
}
4184
}

middleware_accesstoken_test.go

Lines changed: 127 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package rye
22

33
import (
4+
"fmt"
45
"net/http"
56
"net/http/httptest"
7+
"net/url"
68

79
. "github.com/onsi/ginkgo"
810
. "github.com/onsi/gomega"
@@ -14,39 +16,48 @@ var _ = Describe("AccessToken Middleware", func() {
1416
request *http.Request
1517
response *httptest.ResponseRecorder
1618

17-
tokenHeaderName = "at-hname"
18-
token1, token2 string
19+
testHandler func(http.ResponseWriter, *http.Request) *Response
20+
21+
token1, token2 string
1922
)
2023

2124
BeforeEach(func() {
2225
response = httptest.NewRecorder()
23-
request = &http.Request{
24-
Header: map[string][]string{},
25-
}
2626

2727
token1 = "test1"
2828
token2 = "test2"
2929
})
3030

31-
Describe("handle", func() {
31+
Context("header token", func() {
32+
var (
33+
tokenHeaderName = "at-hname"
34+
)
35+
36+
BeforeEach(func() {
37+
testHandler = NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})
38+
request = &http.Request{
39+
Header: map[string][]string{},
40+
}
41+
})
42+
3243
Context("when a valid token is used", func() {
3344
It("should return nil", func() {
3445
request.Header.Add(tokenHeaderName, token1)
35-
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
46+
resp := testHandler(response, request)
3647
Expect(resp).To(BeNil())
3748
})
3849

3950
It("should return nil", func() {
4051
request.Header.Add(tokenHeaderName, token2)
41-
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
52+
resp := testHandler(response, request)
4253
Expect(resp).To(BeNil())
4354
})
4455
})
4556

4657
Context("when an invalid token is used", func() {
4758
It("should return an error", func() {
4859
request.Header.Add(tokenHeaderName, "blah")
49-
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
60+
resp := testHandler(response, request)
5061
Expect(resp).ToNot(BeNil())
5162
Expect(resp.Err).To(HaveOccurred())
5263
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
@@ -56,7 +67,7 @@ var _ = Describe("AccessToken Middleware", func() {
5667

5768
Context("when no token header exists", func() {
5869
It("should return an error", func() {
59-
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
70+
resp := testHandler(response, request)
6071
Expect(resp).ToNot(BeNil())
6172
Expect(resp.Err).To(HaveOccurred())
6273
Expect(resp.Error()).To(ContainSubstring("No access token found"))
@@ -67,12 +78,117 @@ var _ = Describe("AccessToken Middleware", func() {
6778
Context("when token header is blank", func() {
6879
It("should return an error", func() {
6980
request.Header.Add(tokenHeaderName, "")
70-
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
81+
resp := testHandler(response, request)
7182
Expect(resp).ToNot(BeNil())
7283
Expect(resp.Err).To(HaveOccurred())
7384
Expect(resp.Error()).To(ContainSubstring("No access token found"))
7485
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
7586
})
7687
})
7788
})
89+
90+
Context("query param token", func() {
91+
var (
92+
qParamName string
93+
qParams string
94+
)
95+
96+
BeforeEach(func() {
97+
qParamName = "token"
98+
testHandler = NewMiddlewareAccessQueryToken(qParamName, []string{token1, token2})
99+
})
100+
101+
JustBeforeEach(func() {
102+
u, err := url.Parse(fmt.Sprintf("http://doesntmatter.io/blah?%s", qParams))
103+
Expect(err).ToNot(HaveOccurred())
104+
105+
request = &http.Request{
106+
URL: u,
107+
}
108+
})
109+
110+
Context("when a valid token is used", func() {
111+
BeforeEach(func() {
112+
qParams = fmt.Sprintf("%s=%s", qParamName, token1)
113+
})
114+
115+
It("should return nil", func() {
116+
resp := testHandler(response, request)
117+
Expect(resp).To(BeNil())
118+
})
119+
})
120+
121+
Context("when the other valid token is used", func() {
122+
BeforeEach(func() {
123+
qParams = fmt.Sprintf("%s=%s", qParamName, token2)
124+
})
125+
126+
It("should return nil", func() {
127+
resp := testHandler(response, request)
128+
Expect(resp).To(BeNil())
129+
})
130+
})
131+
132+
Context("when an invalid token is used", func() {
133+
BeforeEach(func() {
134+
qParams = fmt.Sprintf("%s=blah", qParamName)
135+
})
136+
137+
It("should return an error", func() {
138+
resp := testHandler(response, request)
139+
Expect(resp).ToNot(BeNil())
140+
Expect(resp.Err).To(HaveOccurred())
141+
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
142+
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
143+
})
144+
})
145+
146+
Context("when no token param exists", func() {
147+
BeforeEach(func() {
148+
qParams = "something=else"
149+
})
150+
151+
It("should return an error", func() {
152+
resp := testHandler(response, request)
153+
Expect(resp).ToNot(BeNil())
154+
Expect(resp.Err).To(HaveOccurred())
155+
Expect(resp.Error()).To(ContainSubstring("No access token found"))
156+
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
157+
})
158+
})
159+
160+
Context("when token param is blank", func() {
161+
BeforeEach(func() {
162+
qParams = fmt.Sprintf("%s=''", qParamName)
163+
})
164+
165+
It("should return an error", func() {
166+
resp := testHandler(response, request)
167+
Expect(resp).ToNot(BeNil())
168+
Expect(resp.Err).To(HaveOccurred())
169+
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
170+
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
171+
})
172+
})
173+
174+
Context("when no query params", func() {
175+
JustBeforeEach(func() {
176+
u, err := url.Parse("http://doesntmatter.io/blah")
177+
Expect(err).ToNot(HaveOccurred())
178+
179+
request = &http.Request{
180+
URL: u,
181+
}
182+
})
183+
184+
It("should return an error", func() {
185+
resp := testHandler(response, request)
186+
Expect(resp).ToNot(BeNil())
187+
Expect(resp.Err).To(HaveOccurred())
188+
Expect(resp.Error()).To(ContainSubstring("No access token found"))
189+
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
190+
})
191+
})
192+
193+
})
78194
})

0 commit comments

Comments
 (0)