Skip to content

Commit 4bfa954

Browse files
authored
Add tls_config field to OAuth 2.0 Config (#331)
* Add `TLSConfig` field Signed-off-by: Levi Harrison <[email protected]> * Add tests Signed-off-by: Levi Harrison <[email protected]>
1 parent 5a26535 commit 4bfa954

File tree

2 files changed

+74
-7
lines changed

2 files changed

+74
-7
lines changed

config/http_config.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ type OAuth2 struct {
159159
Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"`
160160
TokenURL string `yaml:"token_url" json:"token_url"`
161161
EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"`
162+
163+
// TLSConfig is used to connect to the token URL.
164+
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
162165
}
163166

164167
// SetDirectory joins any relative file paths with dir.
@@ -594,7 +597,25 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
594597
EndpointParams: mapToValues(rt.config.EndpointParams),
595598
}
596599

597-
tokenSource := config.TokenSource(context.Background())
600+
tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig)
601+
if err != nil {
602+
return nil, err
603+
}
604+
605+
var t http.RoundTripper
606+
if len(rt.config.TLSConfig.CAFile) == 0 {
607+
t = &http.Transport{TLSClientConfig: tlsConfig}
608+
} else {
609+
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, func(tls *tls.Config) (http.RoundTripper, error) {
610+
return &http.Transport{TLSClientConfig: tls}, nil
611+
})
612+
if err != nil {
613+
return nil, err
614+
}
615+
}
616+
617+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
618+
tokenSource := config.TokenSource(ctx)
598619

599620
rt.mtx.Lock()
600621
rt.secret = secret
@@ -763,7 +784,6 @@ func NewTLSRoundTripper(
763784
return nil, err
764785
}
765786
t.rt = rt
766-
767787
_, t.hashCAFile, err = t.getCAWithHash()
768788
if err != nil {
769789
return nil, err

config/http_config_test.go

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ const (
6666
ExpectedAuthenticationCredentials = AuthorizationType + " " + BearerToken
6767
ExpectedUsername = "arthurdent"
6868
ExpectedPassword = "42"
69+
ExpectedAccessToken = "12345"
6970
)
7071

7172
var invalidHTTPClientConfigs = []struct {
@@ -363,6 +364,45 @@ func TestNewClientFromConfig(t *testing.T) {
363364
}
364365
},
365366
},
367+
{
368+
clientConfig: HTTPClientConfig{
369+
OAuth2: &OAuth2{
370+
ClientID: "ExpectedUsername",
371+
ClientSecret: "ExpectedPassword",
372+
TLSConfig: TLSConfig{
373+
CAFile: TLSCAChainPath,
374+
CertFile: ClientCertificatePath,
375+
KeyFile: ClientKeyNoPassPath,
376+
ServerName: "",
377+
InsecureSkipVerify: false},
378+
},
379+
TLSConfig: TLSConfig{
380+
CAFile: TLSCAChainPath,
381+
CertFile: ClientCertificatePath,
382+
KeyFile: ClientKeyNoPassPath,
383+
ServerName: "",
384+
InsecureSkipVerify: false},
385+
},
386+
handler: func(w http.ResponseWriter, r *http.Request) {
387+
switch r.URL.Path {
388+
case "/token":
389+
res, _ := json.Marshal(oauth2TestServerResponse{
390+
AccessToken: ExpectedAccessToken,
391+
TokenType: "Bearer",
392+
})
393+
w.Header().Add("Content-Type", "application/json")
394+
_, _ = w.Write(res)
395+
396+
default:
397+
authorization := r.Header.Get("Authorization")
398+
if authorization != "Bearer "+ExpectedAccessToken {
399+
fmt.Fprintf(w, "Expected Authorization header %q, got %q", "Bearer "+ExpectedAccessToken, authorization)
400+
} else {
401+
fmt.Fprint(w, ExpectedMessage)
402+
}
403+
}
404+
},
405+
},
366406
}
367407

368408
for _, validConfig := range newClientValidConfig {
@@ -372,6 +412,12 @@ func TestNewClientFromConfig(t *testing.T) {
372412
}
373413
defer testServer.Close()
374414

415+
if validConfig.clientConfig.OAuth2 != nil {
416+
// We don't have access to the test server's URL when configuring the test cases,
417+
// so it has to be specified here.
418+
validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token"
419+
}
420+
375421
err = validConfig.clientConfig.Validate()
376422
if err != nil {
377423
t.Fatal(err.Error())
@@ -381,6 +427,7 @@ func TestNewClientFromConfig(t *testing.T) {
381427
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
382428
continue
383429
}
430+
384431
response, err := client.Get(testServer.URL)
385432
if err != nil {
386433
t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err)
@@ -1129,14 +1176,14 @@ func NewRoundTripCheckRequest(checkRequest func(*http.Request), theResponse *htt
11291176
theError: theError}}
11301177
}
11311178

1132-
type testServerResponse struct {
1179+
type oauth2TestServerResponse struct {
11331180
AccessToken string `json:"access_token"`
11341181
TokenType string `json:"token_type"`
11351182
}
11361183

11371184
func TestOAuth2(t *testing.T) {
11381185
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1139-
res, _ := json.Marshal(testServerResponse{
1186+
res, _ := json.Marshal(oauth2TestServerResponse{
11401187
AccessToken: "12345",
11411188
TokenType: "Bearer",
11421189
})
@@ -1169,7 +1216,7 @@ endpoint_params:
11691216
t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
11701217
}
11711218
if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
1172-
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
1219+
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
11731220
}
11741221

11751222
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
@@ -1197,7 +1244,7 @@ func TestOAuth2WithFile(t *testing.T) {
11971244
t.Fatal("token endpoint called twice")
11981245
}
11991246
previousAuth = auth
1200-
res, _ := json.Marshal(testServerResponse{
1247+
res, _ := json.Marshal(oauth2TestServerResponse{
12011248
AccessToken: "12345",
12021249
TokenType: "Bearer",
12031250
})
@@ -1244,7 +1291,7 @@ endpoint_params:
12441291
t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
12451292
}
12461293
if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
1247-
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
1294+
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
12481295
}
12491296

12501297
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)

0 commit comments

Comments
 (0)