Skip to content

Commit d913768

Browse files
authored
Add DialContext option (#487)
* Add DialContext option It allows to using a cancellable context for a connection * Remove previous Dial-option
1 parent c2bcf7c commit d913768

File tree

5 files changed

+39
-17
lines changed

5 files changed

+39
-17
lines changed

clickhouse.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,13 @@ func (ch *clickhouse) Stats() driver.Stats {
181181
}
182182
}
183183

184-
func (ch *clickhouse) dial() (conn *connect, err error) {
184+
func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) {
185185
connID := int(atomic.AddInt64(&ch.connID, 1))
186186
for num := range ch.opt.Addr {
187187
if ch.opt.ConnOpenStrategy == ConnOpenRoundRobin {
188188
num = int(connID) % len(ch.opt.Addr)
189189
}
190-
if conn, err = dial(ch.opt.Addr[num], connID, ch.opt); err == nil {
190+
if conn, err = dial(ctx, ch.opt.Addr[num], connID, ch.opt); err == nil {
191191
return conn, nil
192192
}
193193
}
@@ -213,7 +213,7 @@ func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) {
213213
case conn := <-ch.idle:
214214
if conn.isBad() {
215215
conn.close()
216-
if conn, err = ch.dial(); err != nil {
216+
if conn, err = ch.dial(ctx); err != nil {
217217
select {
218218
case <-ch.open:
219219
default:
@@ -225,7 +225,7 @@ func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) {
225225
return conn, nil
226226
default:
227227
}
228-
if conn, err = ch.dial(); err != nil {
228+
if conn, err = ch.dial(ctx); err != nil {
229229
select {
230230
case <-ch.open:
231231
default:

clickhouse_options.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package clickhouse
1919

2020
import (
21+
"context"
2122
"crypto/tls"
2223
"fmt"
2324
"net"
@@ -29,9 +30,7 @@ import (
2930
"github.com/ClickHouse/clickhouse-go/v2/lib/compress"
3031
)
3132

32-
var (
33-
CompressionLZ4 compress.Method = compress.LZ4
34-
)
33+
var CompressionLZ4 compress.Method = compress.LZ4
3534

3635
type Auth struct { // has_control_character
3736
Database string
@@ -62,7 +61,7 @@ type Options struct {
6261
TLS *tls.Config
6362
Addr []string
6463
Auth Auth
65-
Dial func(addr string) (net.Conn, error)
64+
DialContext func(ctx context.Context, addr string) (net.Conn, error)
6665
Debug bool
6766
Settings Settings
6867
Compression *Compression

clickhouse_std.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (o *stdConnOpener) Driver() driver.Driver {
4242
return &stdDriver{}
4343
}
4444

45-
func (o *stdConnOpener) Connect(context.Context) (_ driver.Conn, err error) {
45+
func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error) {
4646
if o.err != nil {
4747
return nil, o.err
4848
}
@@ -54,7 +54,7 @@ func (o *stdConnOpener) Connect(context.Context) (_ driver.Conn, err error) {
5454
if o.opt.ConnOpenStrategy == ConnOpenRoundRobin {
5555
num = int(connID) % len(o.opt.Addr)
5656
}
57-
if conn, err = dial(o.opt.Addr[num], connID, o.opt); err == nil {
57+
if conn, err = dial(ctx, o.opt.Addr[num], connID, o.opt); err == nil {
5858
return &stdDriver{
5959
conn: conn,
6060
}, nil

conn.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ import (
3232
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
3333
)
3434

35-
func dial(addr string, num int, opt *Options) (*connect, error) {
35+
func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, error) {
3636
var (
3737
err error
3838
conn net.Conn
3939
debugf = func(format string, v ...interface{}) {}
4040
)
4141
switch {
42-
case opt.Dial != nil:
43-
conn, err = opt.Dial(addr)
42+
case opt.DialContext != nil:
43+
conn, err = opt.DialContext(ctx, addr)
4444
default:
4545
switch {
4646
case opt.TLS != nil:

tests/custom_dial_test.go

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
"github.com/stretchr/testify/assert"
2727
)
2828

29-
func TestCustomDial(t *testing.T) {
29+
func TestCustomDialContext(t *testing.T) {
3030
var (
3131
dialCount int
3232
conn, err = clickhouse.Open(&clickhouse.Options{
@@ -36,16 +36,39 @@ func TestCustomDial(t *testing.T) {
3636
Username: "default",
3737
Password: "",
3838
},
39-
Dial: func(addr string) (net.Conn, error) {
39+
DialContext: func(ctx context.Context, addr string) (net.Conn, error) {
4040
dialCount++
41-
return net.Dial("tcp", addr)
41+
var d net.Dialer
42+
return d.DialContext(ctx, "tcp", addr)
4243
},
4344
})
4445
)
4546
if !assert.NoError(t, err) {
4647
return
4748
}
48-
if err := conn.Ping(context.Background()); assert.NoError(t, err) {
49+
ctx := context.Background()
50+
if err := conn.Ping(ctx); assert.NoError(t, err) {
4951
assert.Equal(t, 1, dialCount)
5052
}
53+
54+
ctx1, cancel1 := context.WithCancel(ctx)
55+
go func() {
56+
cancel1()
57+
}()
58+
59+
ctx2, cancel2 := context.WithCancel(ctx)
60+
defer cancel2()
61+
62+
// query is cancelled with context
63+
err = conn.QueryRow(ctx1, "SELECT sleep(3)").Scan()
64+
if assert.Error(t, err, "context cancelled") {
65+
assert.Equal(t, 1, dialCount)
66+
}
67+
68+
// uncancelled context still works (new connection is acquired)
69+
var i uint8
70+
err = conn.QueryRow(ctx2, "SELECT 1").Scan(&i)
71+
if assert.NoError(t, err) {
72+
assert.Equal(t, 2, dialCount)
73+
}
5174
}

0 commit comments

Comments
 (0)