Skip to content

Commit 50bd6ae

Browse files
committed
Added functional option to allow to customize DialContext() in HTTP client
Signed-off-by: Marco Pracucci <[email protected]>
1 parent 4240322 commit 50bd6ae

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

config/http_config.go

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ package config
1717

1818
import (
1919
"bytes"
20+
"context"
2021
"crypto/sha256"
2122
"crypto/tls"
2223
"crypto/x509"
2324
"fmt"
2425
"io/ioutil"
26+
"net"
2527
"net/http"
2628
"net/url"
2729
"strings"
@@ -194,15 +196,33 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
194196
return unmarshal((*plain)(a))
195197
}
196198

199+
// DialContextFunc defines the signature of the DialContext() function implemented
200+
// by net.Dialer.
201+
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
202+
203+
type httpClientOptions struct {
204+
dialContextFunc DialContextFunc
205+
}
206+
207+
// HTTPClientOption defines an option that can be applied to the HTTP client.
208+
type HTTPClientOption func(options *httpClientOptions)
209+
210+
// WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
211+
func WithDialContextFunc(fn DialContextFunc) HTTPClientOption {
212+
return func(opts *httpClientOptions) {
213+
opts.dialContextFunc = fn
214+
}
215+
}
216+
197217
// NewClient returns a http.Client using the specified http.RoundTripper.
198218
func newClient(rt http.RoundTripper) *http.Client {
199219
return &http.Client{Transport: rt}
200220
}
201221

202222
// NewClientFromConfig returns a new HTTP client configured for the
203223
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
204-
func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (*http.Client, error) {
205-
rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2)
224+
func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool, optFuncs ...HTTPClientOption) (*http.Client, error) {
225+
rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2, optFuncs...)
206226
if err != nil {
207227
return nil, err
208228
}
@@ -217,7 +237,25 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, e
217237

218238
// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
219239
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
220-
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (http.RoundTripper, error) {
240+
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
241+
opts := &httpClientOptions{}
242+
for _, f := range optFuncs {
243+
f(opts)
244+
}
245+
246+
var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
247+
248+
if opts.dialContextFunc != nil {
249+
dialContext = conntrack.NewDialContextFunc(
250+
conntrack.DialWithDialContextFunc((func(context.Context, string, string) (net.Conn, error))(opts.dialContextFunc)),
251+
conntrack.DialWithTracing(),
252+
conntrack.DialWithName(name))
253+
} else {
254+
dialContext = conntrack.NewDialContextFunc(
255+
conntrack.DialWithTracing(),
256+
conntrack.DialWithName(name))
257+
}
258+
221259
newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
222260
// The only timeout we care about is the configured scrape timeout.
223261
// It is applied on request. So we leave out any timings here.
@@ -233,10 +271,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAli
233271
IdleConnTimeout: 5 * time.Minute,
234272
TLSHandshakeTimeout: 10 * time.Second,
235273
ExpectContinueTimeout: 1 * time.Second,
236-
DialContext: conntrack.NewDialContextFunc(
237-
conntrack.DialWithTracing(),
238-
conntrack.DialWithName(name),
239-
),
274+
DialContext: dialContext,
240275
}
241276
if enableHTTP2 {
242277
// HTTP/2 support is golang has many problematic cornercases where

config/http_config_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
package config
1717

1818
import (
19+
"context"
1920
"crypto/tls"
2021
"crypto/x509"
22+
"errors"
2123
"fmt"
2224
"io/ioutil"
25+
"net"
2326
"net/http"
2427
"net/http/httptest"
2528
"os"
@@ -50,6 +53,7 @@ const (
5053
MissingKey = "missing/secret.key"
5154

5255
ExpectedMessage = "I'm here to serve you!!!"
56+
ExpectedError = "expected error"
5357
AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
5458
AuthorizationCredentialsFile = "testdata/bearer.token"
5559
AuthorizationType = "APIKEY"
@@ -413,6 +417,23 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
413417
}
414418
}
415419

420+
func TestCustomDialContextFunc(t *testing.T) {
421+
dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
422+
return nil, errors.New(ExpectedError)
423+
}
424+
425+
cfg := HTTPClientConfig{}
426+
client, err := NewClientFromConfig(cfg, "test", false, true, WithDialContextFunc(dialFn))
427+
if err != nil {
428+
t.Fatalf("Can't create a client from this config: %+v", cfg)
429+
}
430+
431+
_, err = client.Get("http://localhost")
432+
if err == nil || !strings.Contains(err.Error(), ExpectedError) {
433+
t.Errorf("Expected error %q but got %q", ExpectedError, err)
434+
}
435+
}
436+
416437
func TestMissingBearerAuthFile(t *testing.T) {
417438
cfg := HTTPClientConfig{
418439
BearerTokenFile: MissingBearerTokenFile,

0 commit comments

Comments
 (0)