Skip to content

Commit 752114b

Browse files
authoredDec 1, 2023
Add options lambdaurl.WithDetectContentType and lambda.WithContextValue (#516)
·
v1.50.0v1.42.0
1 parent 1dca084 commit 752114b

File tree

6 files changed

+252
-31
lines changed

6 files changed

+252
-31
lines changed
 

‎.github/workflows/tests.yml‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ jobs:
88
name: run tests
99
runs-on: ubuntu-latest
1010
strategy:
11+
fail-fast: false
1112
matrix:
1213
go:
1314
- "1.21"

‎lambda/handler.go‎

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Handler interface {
2323
type handlerOptions struct {
2424
handlerFunc
2525
baseContext context.Context
26+
contextValues map[interface{}]interface{}
2627
jsonRequestUseNumber bool
2728
jsonRequestDisallowUnknownFields bool
2829
jsonResponseEscapeHTML bool
@@ -50,6 +51,23 @@ func WithContext(ctx context.Context) Option {
5051
})
5152
}
5253

54+
// WithContextValue adds a value to the handler context.
55+
// If a base context was set using WithContext, that base is used as the parent.
56+
//
57+
// Usage:
58+
//
59+
// lambda.StartWithOptions(
60+
// func (ctx context.Context) (string, error) {
61+
// return ctx.Value("foo"), nil
62+
// },
63+
// lambda.WithContextValue("foo", "bar")
64+
// )
65+
func WithContextValue(key interface{}, value interface{}) Option {
66+
return Option(func(h *handlerOptions) {
67+
h.contextValues[key] = value
68+
})
69+
}
70+
5371
// WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder
5472
//
5573
// Usage:
@@ -211,13 +229,17 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
211229
}
212230
h := &handlerOptions{
213231
baseContext: context.Background(),
232+
contextValues: map[interface{}]interface{}{},
214233
jsonResponseEscapeHTML: false,
215234
jsonResponseIndentPrefix: "",
216235
jsonResponseIndentValue: "",
217236
}
218237
for _, option := range options {
219238
option(h)
220239
}
240+
for k, v := range h.contextValues {
241+
h.baseContext = context.WithValue(h.baseContext, k, v)
242+
}
221243
if h.enableSIGTERM {
222244
enableSIGTERM(h.sigtermCallbacks)
223245
}

‎lambda/sigterm_test.go‎

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010
"os/exec"
1111
"path"
12+
"strconv"
1213
"strings"
1314
"testing"
1415
"time"
@@ -17,10 +18,6 @@ import (
1718
"github.com/stretchr/testify/require"
1819
)
1920

20-
const (
21-
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
22-
)
23-
2421
func TestEnableSigterm(t *testing.T) {
2522
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
2623
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
@@ -34,6 +31,7 @@ func TestEnableSigterm(t *testing.T) {
3431
handlerBuild.Stdout = os.Stderr
3532
require.NoError(t, handlerBuild.Run())
3633

34+
portI := 0
3735
for name, opts := range map[string]struct {
3836
envVars []string
3937
assertLogs func(t *testing.T, logs string)
@@ -53,8 +51,12 @@ func TestEnableSigterm(t *testing.T) {
5351
},
5452
} {
5553
t.Run(name, func(t *testing.T) {
54+
portI += 1
55+
addr1 := "localhost:" + strconv.Itoa(8000+portI)
56+
addr2 := "localhost:" + strconv.Itoa(9000+portI)
57+
rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations"
5658
// run the runtime interface emulator, capture the logs for assertion
57-
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
59+
cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "sigterm.handler")
5860
cmd.Env = append([]string{
5961
"PATH=" + testDir,
6062
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",

‎lambdaurl/http_handler.go‎

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,76 @@ import (
1818
"github.com/aws/aws-lambda-go/lambda"
1919
)
2020

21+
type detectContentTypeContextKey struct{}
22+
23+
// WithDetectContentType sets the behavior of content type detection when the Content-Type header is not already provided.
24+
// When true, the first Write call will pass the intial bytes to http.DetectContentType.
25+
// When false, and if no Content-Type is provided, no Content-Type will be sent back to Lambda,
26+
// and the Lambda Function URL will fallback to it's default.
27+
//
28+
// Note: The http.ResponseWriter passed to the handler is unbuffered.
29+
// This may result in different Content-Type headers in the Function URL response when compared to http.ListenAndServe.
30+
//
31+
// Usage:
32+
//
33+
// lambdaurl.Start(
34+
// http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
35+
// w.Write("<!DOCTYPE html><html></html>")
36+
// }),
37+
// lambdaurl.WithDetectContentType(true)
38+
// )
39+
func WithDetectContentType(detectContentType bool) lambda.Option {
40+
return lambda.WithContextValue(detectContentTypeContextKey{}, detectContentType)
41+
}
42+
2143
type httpResponseWriter struct {
44+
detectContentType bool
45+
header http.Header
46+
writer io.Writer
47+
once sync.Once
48+
ready chan<- header
49+
}
50+
51+
type header struct {
52+
code int
2253
header http.Header
23-
writer io.Writer
24-
once sync.Once
25-
status chan<- int
2654
}
2755

2856
func (w *httpResponseWriter) Header() http.Header {
57+
if w.header == nil {
58+
w.header = http.Header{}
59+
}
2960
return w.header
3061
}
3162

3263
func (w *httpResponseWriter) Write(p []byte) (int, error) {
33-
w.once.Do(func() { w.status <- http.StatusOK })
64+
w.writeHeader(http.StatusOK, p)
3465
return w.writer.Write(p)
3566
}
3667

3768
func (w *httpResponseWriter) WriteHeader(statusCode int) {
38-
w.once.Do(func() { w.status <- statusCode })
69+
w.writeHeader(statusCode, nil)
70+
}
71+
72+
func (w *httpResponseWriter) writeHeader(statusCode int, initialPayload []byte) {
73+
w.once.Do(func() {
74+
if w.detectContentType {
75+
if w.Header().Get("Content-Type") == "" {
76+
w.Header().Set("Content-Type", detectContentType(initialPayload))
77+
}
78+
}
79+
w.ready <- header{code: statusCode, header: w.header}
80+
})
81+
}
82+
83+
func detectContentType(p []byte) string {
84+
// http.DetectContentType returns "text/plain; charset=utf-8" for nil and zero-length byte slices.
85+
// This is a weird behavior, since otherwise it defaults to "application/octet-stream"! So we'll do that.
86+
// This differs from http.ListenAndServe, which set no Content-Type when the initial Flush body is empty.
87+
if len(p) == 0 {
88+
return "application/octet-stream"
89+
}
90+
return http.DetectContentType(p)
3991
}
4092

4193
type requestContextKey struct{}
@@ -46,11 +98,13 @@ func RequestFromContext(ctx context.Context) (*events.LambdaFunctionURLRequest,
4698
return req, ok
4799
}
48100

49-
// Wrap converts an http.Handler into a lambda request handler.
101+
// Wrap converts an http.Handler into a Lambda request handler.
102+
//
50103
// Only Lambda Function URLs configured with `InvokeMode: RESPONSE_STREAM` are supported with the returned handler.
51-
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`
104+
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`.
52105
func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {
53106
return func(ctx context.Context, request *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {
107+
54108
var body io.Reader = strings.NewReader(request.Body)
55109
if request.IsBase64Encoded {
56110
body = base64.NewDecoder(base64.StdEncoding, body)
@@ -67,21 +121,28 @@ func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLR
67121
for k, v := range request.Headers {
68122
httpRequest.Header.Add(k, v)
69123
}
70-
status := make(chan int) // Signals when it's OK to start returning the response body to Lambda
71-
header := http.Header{}
124+
125+
ready := make(chan header) // Signals when it's OK to start returning the response body to Lambda
72126
r, w := io.Pipe()
127+
responseWriter := &httpResponseWriter{writer: w, ready: ready}
128+
if detectContentType, ok := ctx.Value(detectContentTypeContextKey{}).(bool); ok {
129+
responseWriter.detectContentType = detectContentType
130+
}
73131
go func() {
74-
defer close(status)
132+
defer close(ready)
75133
defer w.Close() // TODO: recover and CloseWithError the any panic value once the runtime API client supports plumbing fatal errors through the reader
76-
handler.ServeHTTP(&httpResponseWriter{writer: w, header: header, status: status}, httpRequest)
134+
//nolint:errcheck
135+
defer responseWriter.Write(nil) // force default status, headers, content type detection, if none occured during the execution of the handler
136+
handler.ServeHTTP(responseWriter, httpRequest)
77137
}()
138+
header := <-ready
78139
response := &events.LambdaFunctionURLStreamingResponse{
79140
Body: r,
80-
StatusCode: <-status,
141+
StatusCode: header.code,
81142
}
82-
if len(header) > 0 {
83-
response.Headers = make(map[string]string, len(header))
84-
for k, v := range header {
143+
if len(header.header) > 0 {
144+
response.Headers = make(map[string]string, len(header.header))
145+
for k, v := range header.header {
85146
if k == "Set-Cookie" {
86147
response.Cookies = v
87148
} else {

‎lambdaurl/http_handler_test.go‎

Lines changed: 121 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ import (
1313
"io/ioutil"
1414
"log"
1515
"net/http"
16+
"os"
17+
"os/exec"
18+
"path"
19+
"strconv"
20+
"strings"
1621
"testing"
1722
"time"
1823

@@ -35,12 +40,13 @@ var base64EncodedBodyRequest []byte
3540

3641
func TestWrap(t *testing.T) {
3742
for name, params := range map[string]struct {
38-
input []byte
39-
handler http.HandlerFunc
40-
expectStatus int
41-
expectBody string
42-
expectHeaders map[string]string
43-
expectCookies []string
43+
input []byte
44+
handler http.HandlerFunc
45+
detectContentType bool
46+
expectStatus int
47+
expectBody string
48+
expectHeaders map[string]string
49+
expectCookies []string
4450
}{
4551
"hello": {
4652
input: helloRequest,
@@ -58,10 +64,8 @@ func TestWrap(t *testing.T) {
5864
encoder := json.NewEncoder(w)
5965
_ = encoder.Encode(struct{ RequestQueryParams, Method any }{r.URL.Query(), r.Method})
6066
},
61-
expectStatus: http.StatusTeapot,
62-
expectHeaders: map[string]string{
63-
"Hello": "world1,world2",
64-
},
67+
expectStatus: http.StatusTeapot,
68+
expectHeaders: map[string]string{"Hello": "world1,world2"},
6569
expectCookies: []string{
6670
"yummy=cookie",
6771
"yummy=cake",
@@ -110,6 +114,13 @@ func TestWrap(t *testing.T) {
110114
handler: func(w http.ResponseWriter, r *http.Request) {},
111115
expectStatus: http.StatusOK,
112116
},
117+
"write status code only": {
118+
input: helloRequest,
119+
handler: func(w http.ResponseWriter, r *http.Request) {
120+
w.WriteHeader(http.StatusAccepted)
121+
},
122+
expectStatus: http.StatusAccepted,
123+
},
113124
"base64request": {
114125
input: base64EncodedBodyRequest,
115126
handler: func(w http.ResponseWriter, r *http.Request) {
@@ -118,12 +129,58 @@ func TestWrap(t *testing.T) {
118129
expectStatus: http.StatusOK,
119130
expectBody: "<idk/>",
120131
},
132+
"detect content type: write status code only": {
133+
input: helloRequest,
134+
handler: func(w http.ResponseWriter, r *http.Request) {
135+
w.WriteHeader(http.StatusAccepted)
136+
},
137+
detectContentType: true,
138+
expectStatus: http.StatusAccepted,
139+
expectHeaders: map[string]string{
140+
"Content-Type": "application/octet-stream",
141+
},
142+
},
143+
"detect content type: empty handler": {
144+
input: helloRequest,
145+
handler: func(w http.ResponseWriter, r *http.Request) {
146+
},
147+
detectContentType: true,
148+
expectStatus: http.StatusOK,
149+
expectHeaders: map[string]string{
150+
"Content-Type": "application/octet-stream",
151+
},
152+
},
153+
"detect content type: writes html": {
154+
input: helloRequest,
155+
handler: func(w http.ResponseWriter, r *http.Request) {
156+
_, _ = w.Write([]byte("<!DOCTYPE HTML><html></html>"))
157+
},
158+
detectContentType: true,
159+
expectBody: "<!DOCTYPE HTML><html></html>",
160+
expectStatus: http.StatusOK,
161+
expectHeaders: map[string]string{
162+
"Content-Type": "text/html; charset=utf-8",
163+
},
164+
},
165+
"detect content type: writes zeros": {
166+
input: helloRequest,
167+
handler: func(w http.ResponseWriter, r *http.Request) {
168+
_, _ = w.Write([]byte{0, 0, 0, 0, 0})
169+
},
170+
detectContentType: true,
171+
expectBody: "\x00\x00\x00\x00\x00",
172+
expectStatus: http.StatusOK,
173+
expectHeaders: map[string]string{
174+
"Content-Type": "application/octet-stream",
175+
},
176+
},
121177
} {
122178
t.Run(name, func(t *testing.T) {
123179
handler := Wrap(params.handler)
124180
var req events.LambdaFunctionURLRequest
125181
require.NoError(t, json.Unmarshal(params.input, &req))
126-
res, err := handler(context.Background(), &req)
182+
ctx := context.WithValue(context.Background(), detectContentTypeContextKey{}, params.detectContentType)
183+
res, err := handler(ctx, &req)
127184
require.NoError(t, err)
128185
resultBodyBytes, err := ioutil.ReadAll(res)
129186
require.NoError(t, err)
@@ -155,3 +212,56 @@ func TestRequestContext(t *testing.T) {
155212
_, err := handler(context.Background(), req)
156213
require.NoError(t, err)
157214
}
215+
216+
func TestStartViaEmulator(t *testing.T) {
217+
addr1 := "localhost:" + strconv.Itoa(6001)
218+
addr2 := "localhost:" + strconv.Itoa(7001)
219+
rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations"
220+
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
221+
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
222+
}
223+
224+
// compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie
225+
testDir := t.TempDir()
226+
handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "lambdaurl.handler"), "./testdata/lambdaurl.go")
227+
handlerBuild.Stderr = os.Stderr
228+
handlerBuild.Stdout = os.Stderr
229+
require.NoError(t, handlerBuild.Run())
230+
231+
// run the runtime interface emulator, capture the logs for assertion
232+
cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "lambdaurl.handler")
233+
cmd.Env = []string{
234+
"PATH=" + testDir,
235+
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
236+
}
237+
cmd.Stderr = os.Stderr
238+
stdout, err := cmd.StdoutPipe()
239+
require.NoError(t, err)
240+
var logs string
241+
done := make(chan interface{}) // closed on completion of log flush
242+
go func() {
243+
logBytes, err := ioutil.ReadAll(stdout)
244+
require.NoError(t, err)
245+
logs = string(logBytes)
246+
close(done)
247+
}()
248+
require.NoError(t, cmd.Start())
249+
t.Cleanup(func() { _ = cmd.Process.Kill() })
250+
251+
// give a moment for the port to bind
252+
time.Sleep(500 * time.Millisecond)
253+
254+
client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie
255+
resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}"))
256+
require.NoError(t, err)
257+
defer resp.Body.Close()
258+
body, err := ioutil.ReadAll(resp.Body)
259+
assert.NoError(t, err)
260+
261+
expected := "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"text/html; charset=utf-8\"}}\x00\x00\x00\x00\x00\x00\x00\x00<!DOCTYPE HTML>\n<html>\n<body>\nHello World!\n</body>\n</html>\n"
262+
assert.Equal(t, expected, string(body))
263+
264+
require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained
265+
<-done
266+
t.Logf("stdout:\n%s", logs)
267+
}

‎lambdaurl/testdata/lambdaurl.go‎

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package main
2+
3+
import (
4+
"io"
5+
"net/http"
6+
"strings"
7+
8+
"github.com/aws/aws-lambda-go/lambdaurl"
9+
)
10+
11+
const content = `<!DOCTYPE HTML>
12+
<html>
13+
<body>
14+
Hello World!
15+
</body>
16+
</html>
17+
`
18+
19+
func main() {
20+
lambdaurl.Start(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21+
_, _ = io.Copy(w, strings.NewReader(content))
22+
}),
23+
lambdaurl.WithDetectContentType(true),
24+
)
25+
}

0 commit comments

Comments
 (0)
Please sign in to comment.