Skip to content

Commit 3d7bd11

Browse files
committed
Exec: bind values; Fix 'INSERT INTO ...SELECT' parsing
1 parent b964252 commit 3d7bd11

File tree

4 files changed

+31
-29
lines changed

4 files changed

+31
-29
lines changed

clickhouse.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,7 @@ func (ch *clickhouse) Rollback() error {
168168
}
169169
ch.data = nil
170170
ch.inTransaction = false
171-
if err := ch.cancel(); err != nil {
172-
return err
173-
}
174-
return driver.ErrBadConn
171+
return nil
175172
}
176173

177174
func (ch *clickhouse) Close() error {

clickhouse_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package clickhouse_test
22

33
import (
44
"database/sql"
5-
"database/sql/driver"
65
"fmt"
76
"strings"
87
"testing"
@@ -704,7 +703,7 @@ func Test_Tx(t *testing.T) {
704703
if tx, err := connect.Begin(); assert.NoError(t, err) {
705704
_, err = tx.Query("SELECT 1")
706705
if assert.NoError(t, err) {
707-
if !assert.Equal(t, driver.ErrBadConn, tx.Rollback()) {
706+
if !assert.NoError(t, tx.Rollback()) {
708707
return
709708
}
710709
}

helpers.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"database/sql/driver"
66
"fmt"
7+
"regexp"
78
"strings"
89
"time"
910
)
@@ -96,9 +97,11 @@ func paramParser(reader *bytes.Reader) string {
9697
return name.String()
9798
}
9899

100+
var selectRe = regexp.MustCompile(`\s+SELECT\s+`)
101+
99102
func isInsert(query string) bool {
100103
if f := strings.Fields(query); len(f) > 2 {
101-
return strings.EqualFold("INSERT", f[0]) && strings.EqualFold("INTO", f[1]) && strings.Index(strings.ToUpper(query), " SELECT ") == -1
104+
return strings.EqualFold("INSERT", f[0]) && strings.EqualFold("INTO", f[1]) && !selectRe.MatchString(strings.ToUpper(query))
102105
}
103106
return false
104107
}

stmt.go

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ func (stmt *stmt) execContext(ctx context.Context, args []driver.Value) (driver.
4343
}
4444
return emptyResult, nil
4545
}
46-
if err := stmt.ch.sendQuery(stmt.query); err != nil {
46+
47+
if err := stmt.ch.sendQuery(stmt.bind(convertOldArgs(args))); err != nil {
4748
return nil, err
4849
}
4950
if _, err := stmt.ch.receiveData(); err != nil {
@@ -57,16 +58,33 @@ func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
5758
}
5859

5960
func (stmt *stmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
61+
if finish := stmt.ch.watchCancel(ctx); finish != nil {
62+
defer finish()
63+
}
64+
65+
if err := stmt.ch.sendQuery(stmt.bind(args)); err != nil {
66+
return nil, err
67+
}
68+
69+
rows, err := stmt.ch.receiveData()
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
return rows, nil
75+
}
76+
77+
func (stmt *stmt) Close() error {
78+
stmt.ch.logf("[stmt] close")
79+
return nil
80+
}
81+
82+
func (stmt *stmt) bind(args []namedValue) string {
6083
var (
6184
buf bytes.Buffer
6285
index int
6386
keyword bool
6487
)
65-
66-
if finish := stmt.ch.watchCancel(ctx); finish != nil {
67-
defer finish()
68-
}
69-
7088
switch {
7189
case stmt.NumInput() != 0:
7290
reader := bytes.NewReader([]byte(stmt.query))
@@ -110,22 +128,7 @@ func (stmt *stmt) queryContext(ctx context.Context, args []namedValue) (driver.R
110128
default:
111129
buf.WriteString(stmt.query)
112130
}
113-
114-
if err := stmt.ch.sendQuery(buf.String()); err != nil {
115-
return nil, err
116-
}
117-
118-
rows, err := stmt.ch.receiveData()
119-
if err != nil {
120-
return nil, err
121-
}
122-
123-
return rows, nil
124-
}
125-
126-
func (stmt *stmt) Close() error {
127-
stmt.ch.logf("[stmt] close")
128-
return nil
131+
return buf.String()
129132
}
130133

131134
type namedValue struct {

0 commit comments

Comments
 (0)