Skip to content

Commit a39c8aa

Browse files
committed
Add RegisterDial function: #372
1 parent b68d962 commit a39c8aa

File tree

2 files changed

+59
-15
lines changed

2 files changed

+59
-15
lines changed

clickhouse_test.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,7 +1774,7 @@ func Test_Tuple(t *testing.T) {
17741774
?,
17751775
?,
17761776
?,
1777-
?,
1777+
?,
17781778
?,
17791779
?,
17801780
?,
@@ -2061,7 +2061,7 @@ func Test_ReadHistogram(t *testing.T) {
20612061
?,
20622062
?,
20632063
?,
2064-
?,
2064+
?,
20652065
?,
20662066
?,
20672067
?,
@@ -2173,12 +2173,12 @@ func Test_ReadHistogram(t *testing.T) {
21732173
func Test_ReadArrayArrayTuple(t *testing.T) {
21742174
const (
21752175
query = `
2176-
select
2176+
select
21772177
[
2178-
[(1.0, 2.0, 3.0)],
2179-
[(4.0, 5.0, 6.0), (7.0, 8.0, 9.0)],
2178+
[(1.0, 2.0, 3.0)],
2179+
[(4.0, 5.0, 6.0), (7.0, 8.0, 9.0)],
21802180
[(10.0, 11.0, 12.0), (13.0, 14.0, 15.0), (16.0, 17.0, 18.0), (19.0, 20.0, 21.0), (22.0, 23.0, 24.0)]
2181-
],
2181+
],
21822182
number
21832183
from numbers(2)
21842184
group by number;
@@ -2229,3 +2229,13 @@ func Test_ReadArrayArrayTuple(t *testing.T) {
22292229
}
22302230
}
22312231
}
2232+
2233+
func Test_RegisterDial(t *testing.T) {
2234+
clickhouse.RegisterDial(func(network, address string, timeout time.Duration, config *tls.Config) (net.Conn, error) {
2235+
return net.DialTimeout(network, address, timeout)
2236+
})
2237+
if connect, err := sql.Open("clickhouse", "tcp://127.0.0.1:9000?debug=true"); assert.NoError(t, err) {
2238+
assert.NoError(t, connect.Ping())
2239+
}
2240+
clickhouse.DeregisterDial()
2241+
}

connect.go

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/tls"
66
"database/sql/driver"
77
"net"
8+
"sync"
89
"sync/atomic"
910
"time"
1011
)
@@ -39,6 +40,28 @@ type connOptions struct {
3940
logf func(string, ...interface{})
4041
}
4142

43+
// DialFunc is a function which can be used to establish the network connection.
44+
// Custom dial functions must be registered with RegisterDial
45+
type DialFunc func(network, address string, timeout time.Duration, config *tls.Config) (net.Conn, error)
46+
47+
var (
48+
customDialLock sync.RWMutex
49+
customDial DialFunc
50+
)
51+
52+
// RegisterDial registers a custom dial function.
53+
func RegisterDial(dial DialFunc) {
54+
customDialLock.Lock()
55+
customDial = dial
56+
customDialLock.Unlock()
57+
}
58+
59+
// DeregisterDial deregisters the custom dial function.
60+
func DeregisterDial() {
61+
customDialLock.Lock()
62+
customDial = nil
63+
customDialLock.Unlock()
64+
}
4265
func dial(options connOptions) (*connect, error) {
4366
var (
4467
err error
@@ -74,18 +97,29 @@ func dial(options connOptions) (*connect, error) {
7497
}
7598
checkedHosts[num] = struct{}{}
7699
}
100+
customDialLock.RLock()
101+
cd := customDial
102+
customDialLock.RUnlock()
77103
switch {
78104
case options.secure:
79-
conn, err = tls.DialWithDialer(
80-
&net.Dialer{
81-
Timeout: options.connTimeout,
82-
},
83-
"tcp",
84-
options.hosts[num],
85-
tlsConfig,
86-
)
105+
if cd != nil {
106+
conn, err = cd("tcp", options.hosts[num], options.connTimeout, tlsConfig)
107+
} else {
108+
conn, err = tls.DialWithDialer(
109+
&net.Dialer{
110+
Timeout: options.connTimeout,
111+
},
112+
"tcp",
113+
options.hosts[num],
114+
tlsConfig,
115+
)
116+
}
87117
default:
88-
conn, err = net.DialTimeout("tcp", options.hosts[num], options.connTimeout)
118+
if cd != nil {
119+
conn, err = cd("tcp", options.hosts[num], options.connTimeout, nil)
120+
} else {
121+
conn, err = net.DialTimeout("tcp", options.hosts[num], options.connTimeout)
122+
}
89123
}
90124
if err == nil {
91125
options.logf(

0 commit comments

Comments
 (0)