Skip to content

Commit 71daecb

Browse files
committed
Ensure transactions are closed in pqtest
Just calling db.Close() isn't enough, as it waits for queries to finish, and it won't kill the BEGIN; etc. For regular PostgreSQL tests this isn't a huge problem, most of the time, but it does show up with something like "go test -count=50" because eventually it will run out of connections. e.g. pgbouncer and Supavisor tend to run out of connections even sooner, with pgbouncer just waiting indefinitely until a connection gets closed, causing the pgbouncer tests to be flaky on the CI. Also wrap everything in TestMain() that tests if there are no connections at the end of the test run.
1 parent 8f44823 commit 71daecb

File tree

10 files changed

+191
-75
lines changed

10 files changed

+191
-75
lines changed

conn_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,7 @@ func TestStmtQueryContext(t *testing.T) {
10731073
if err != nil {
10741074
t.Fatal(err)
10751075
}
1076+
defer stmt.Close()
10761077
_, err = stmt.QueryContext(ctx)
10771078
if !pqtest.ErrorContains(err, tt.wantErr) {
10781079
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
@@ -1116,6 +1117,7 @@ func TestStmtExecContext(t *testing.T) {
11161117
if err != nil {
11171118
t.Fatal(err)
11181119
}
1120+
defer stmt.Close()
11191121
_, err = stmt.ExecContext(ctx)
11201122
if !pqtest.ErrorContains(err, tt.wantErr) {
11211123
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
@@ -1745,7 +1747,6 @@ func BenchmarkSelect(b *testing.B) {
17451747
func BenchmarkPreparedSelect(b *testing.B) {
17461748
run := func(b *testing.B, result any, query string) {
17471749
stmt := pqtest.Prepare(b, pqtest.MustDB(b), query).Stmt
1748-
defer stmt.Close()
17491750

17501751
b.ResetTimer()
17511752
for i := 0; i < b.N; i++ {
@@ -1808,6 +1809,7 @@ func BenchmarkPreparedSelect(b *testing.B) {
18081809
if err != nil {
18091810
b.Fatal(err)
18101811
}
1812+
defer stmt.Close()
18111813

18121814
b.ResetTimer()
18131815
for i := 0; i < b.N; i++ {
@@ -1828,6 +1830,7 @@ func BenchmarkPreparedSelect(b *testing.B) {
18281830
if err != nil {
18291831
b.Fatal(err)
18301832
}
1833+
defer stmt.Close()
18311834

18321835
b.ResetTimer()
18331836
for i := 0; i < b.N; i++ {

copy_test.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ func TestCopyInError(t *testing.T) {
4545

4646
func TestCopyInErrorWrongType(t *testing.T) {
4747
t.Parallel()
48-
tx := pqtest.Begin(t, pqtest.MustDB(t))
48+
db := pqtest.MustDB(t)
49+
tx := pqtest.Begin(t, db)
4950
pqtest.Exec(t, tx, `create temp table tbl (num integer)`)
5051

51-
stmt := pqtest.Prepare(t, tx, `copy tbl (num) from stdin`)
52+
stmt := pqtest.Prepare(t, tx, `copy tbl (num) from stdin`, db)
5253
stmt.MustExec(t, "Héllö\n ☃!\r\t\\")
5354
_, err := stmt.Exec()
5455
mustAs(t, err, pqerror.InvalidTextRepresentation)
@@ -66,10 +67,11 @@ func TestCopyInErrorOutsideTransaction(t *testing.T) {
6667

6768
func TestCopyInQueryWhileCopy(t *testing.T) {
6869
t.Parallel()
69-
tx := pqtest.Begin(t, pqtest.MustDB(t))
70+
db := pqtest.MustDB(t)
71+
tx := pqtest.Begin(t, db)
7072
pqtest.Exec(t, tx, `create temp table tbl (i int primary key)`)
7173

72-
pqtest.Prepare(t, tx, "copy tbl (i) from stdin")
74+
pqtest.Prepare(t, tx, "copy tbl (i) from stdin", db)
7375
_, err := tx.Query(`select 1`)
7476
if !errors.Is(err, errQueryInProgress) {
7577
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, errQueryInProgress)
@@ -97,10 +99,11 @@ func TestCopyInNull(t *testing.T) {
9799
tt := tt
98100
t.Run("", func(t *testing.T) {
99101
t.Parallel()
100-
tx := pqtest.Begin(t, pqtest.MustDB(t))
102+
db := pqtest.MustDB(t)
103+
tx := pqtest.Begin(t, db)
101104

102105
pqtest.Exec(t, tx, `create temp table tbl (i int, t text)`)
103-
stmt := pqtest.Prepare(t, tx, tt.copy)
106+
stmt := pqtest.Prepare(t, tx, tt.copy, db)
104107
stmt.MustExec(t, 42, "forty-two")
105108
stmt.MustExec(t, tt.null, tt.null)
106109
stmt.MustExec(t)
@@ -130,10 +133,11 @@ func TestCopyInMultipleValues(t *testing.T) {
130133
tt := tt
131134
t.Run("", func(t *testing.T) {
132135
t.Parallel()
133-
tx := pqtest.Begin(t, pqtest.MustDB(t))
136+
db := pqtest.MustDB(t)
137+
tx := pqtest.Begin(t, db)
134138
pqtest.Exec(t, tx, `create temp table tbl (a int, b varchar)`)
135139

136-
stmt := pqtest.Prepare(t, tx, tt.query)
140+
stmt := pqtest.Prepare(t, tx, tt.query, db)
137141
for i := 0; i < 500; i++ {
138142
stmt.MustExec(t, int64(i), strings.Repeat("#", 500))
139143
}
@@ -161,7 +165,8 @@ func TestCopyInMultipleValues(t *testing.T) {
161165

162166
func TestCopyInRaiseStmtTrigger(t *testing.T) {
163167
t.Parallel()
164-
tx := pqtest.Begin(t, pqtest.MustDB(t))
168+
db := pqtest.MustDB(t)
169+
tx := pqtest.Begin(t, db)
165170
pqtest.Exec(t, tx, `create temp table tbl (a int, b varchar)`)
166171
pqtest.Exec(t, tx, `
167172
create or replace function pg_temp.temptest()
@@ -178,7 +183,7 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) {
178183
for each row execute procedure pg_temp.temptest()
179184
`)
180185

181-
stmt := pqtest.Prepare(t, tx, `copy tbl (a, b) from stdin`)
186+
stmt := pqtest.Prepare(t, tx, `copy tbl (a, b) from stdin`, db)
182187
stmt.MustExec(t, int64(1), strings.Repeat("#", 500))
183188
stmt.MustExec(t)
184189
stmt.MustClose(t)
@@ -195,10 +200,11 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) {
195200

196201
func TestCopyInTypes(t *testing.T) {
197202
t.Parallel()
198-
tx := pqtest.Begin(t, pqtest.MustDB(t))
203+
db := pqtest.MustDB(t)
204+
tx := pqtest.Begin(t, db)
199205
pqtest.Exec(t, tx, `create temp table tbl (num integer, text varchar, blob bytea, nothing varchar)`)
200206

201-
stmt := pqtest.Prepare(t, tx, `copy tbl (num, text, blob, nothing) from stdin`)
207+
stmt := pqtest.Prepare(t, tx, `copy tbl (num, text, blob, nothing) from stdin`, db)
202208
stmt.MustExec(t, int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil)
203209
stmt.MustExec(t)
204210
stmt.MustClose(t)
@@ -241,7 +247,7 @@ func TestCopyInRespLoopConnectionError(t *testing.T) {
241247

242248
pid := pqtest.Query[int64](t, tx, `select pg_backend_pid() as pid`)
243249
pqtest.Exec(t, tx, "create temp table tbl (a int)")
244-
stmt := pqtest.Prepare(t, tx, `copy tbl (a) from stdin`)
250+
stmt := pqtest.Prepare(t, tx, `copy tbl (a) from stdin`, db)
245251
pqtest.Exec(t, db, `select pg_terminate_backend($1)`, pid[0]["pid"])
246252

247253
var err error
@@ -271,10 +277,11 @@ func TestCopyInRespLoopConnectionError(t *testing.T) {
271277
}
272278

273279
func BenchmarkCopyIn(b *testing.B) {
274-
tx := pqtest.Begin(b, pqtest.MustDB(b))
280+
db := pqtest.MustDB(b)
281+
tx := pqtest.Begin(b, db)
275282

276283
pqtest.Exec(b, tx, `create temp table tbl (a int, b varchar)`)
277-
stmt := pqtest.Prepare(b, tx, `copy tbl (a, b) from stdin`)
284+
stmt := pqtest.Prepare(b, tx, `copy tbl (a, b) from stdin`, db)
278285

279286
b.ResetTimer()
280287
for i := 0; i < b.N; i++ {

encode_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ func TestByteaOutputFormats(t *testing.T) {
334334
}
335335

336336
{ // Same but with Prepare
337-
stmt := pqtest.Prepare(t, tx, `select decode('5c7800ff6162630108', 'hex')`)
337+
stmt := pqtest.Prepare(t, tx, `select decode('5c7800ff6162630108', 'hex')`, db)
338338
rows, err := stmt.Query()
339339
if err != nil {
340340
t.Fatal(err)

error_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ func TestNetworkError(t *testing.T) {
255255
}
256256
c.Dialer(failDialer{})
257257
db := sql.OpenDB(c)
258+
defer db.Close()
258259
db.SetMaxIdleConns(1)
259260
db.SetMaxOpenConns(1)
260261
if err := db.Ping(); err != nil {

example_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ func Example_copyFromStdin() {
142142
if err != nil {
143143
log.Fatal(err)
144144
}
145+
defer stmt.Close()
145146

146147
// Insert rows.
147148
users := []struct {

helper_test.go

Lines changed: 0 additions & 29 deletions
This file was deleted.

internal/pqtest/pqtest.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ func Home(t *testing.T) string {
102102
return pqutil.Home(true)
103103
}
104104

105+
var (
106+
mu sync.Mutex
107+
cleanups = make(map[*sql.DB][]func())
108+
)
109+
105110
// DB connects to the test database and returns the Ping error. The connection
106111
// is closed in t.Cleanup().
107112
func DB(t testing.TB, conninfo ...string) (*sql.DB, error) {
@@ -110,7 +115,15 @@ func DB(t testing.TB, conninfo ...string) (*sql.DB, error) {
110115
if err != nil {
111116
t.Fatalf("pqtest.DB: %s", err)
112117
}
113-
t.Cleanup(func() { db.Close() })
118+
t.Cleanup(func() {
119+
mu.Lock()
120+
defer mu.Unlock()
121+
for i := len(cleanups[db]) - 1; i >= 0; i-- {
122+
cleanups[db][i]()
123+
}
124+
delete(cleanups, db)
125+
db.Close()
126+
})
114127
return db, db.Ping()
115128
}
116129

@@ -132,10 +145,9 @@ func Begin(t testing.TB, db *sql.DB) *sql.Tx {
132145
if err != nil {
133146
t.Fatalf("pqtest.Begin: %s", err)
134147
}
135-
// We can't call t.Cleanup here as that will race with the t.Cleanup from
136-
// MustDB (it's called in "last added, first called", so the tx.Rollback
137-
// gets called after db.Close)
138-
// t.Cleanup(func() { tx.Rollback() })
148+
mu.Lock()
149+
defer mu.Unlock()
150+
cleanups[db] = append(cleanups[db], func() { tx.Rollback() })
139151
return tx
140152
}
141153

@@ -163,12 +175,25 @@ func (s *Stmt) MustClose(t testing.TB) {
163175
// Prepare a new statement, calling t.Fatal() if this fails.
164176
func Prepare(t testing.TB, db interface {
165177
Prepare(string) (*sql.Stmt, error)
166-
}, q string) *Stmt {
178+
}, q string, sqldb ...*sql.DB) *Stmt {
167179
t.Helper()
168180
stmt, err := db.Prepare(q)
169181
if err != nil {
170182
t.Fatalf("pqtest.Prepare: %s", err)
171183
}
184+
185+
if len(sqldb) == 0 {
186+
sqldb = make([]*sql.DB, 1)
187+
var ok bool
188+
sqldb[0], ok = db.(*sql.DB)
189+
if !ok {
190+
t.Fatalf("pqtest.Prepare: must pass sql.DB when using transaction")
191+
}
192+
}
193+
mu.Lock()
194+
defer mu.Unlock()
195+
cleanups[sqldb[0]] = append(cleanups[sqldb[0]], func() { stmt.Close() })
196+
172197
return &Stmt{stmt}
173198
}
174199

internal/pqtest/pqtest_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package pqtest_test
2+
3+
import (
4+
"testing"
5+
6+
_ "github.com/lib/pq"
7+
"github.com/lib/pq/internal/pqtest"
8+
)
9+
10+
// Just calling db.Close() isn't enough, as it waits for queries to finish, and
11+
// it won't kill the BEGIN; etc.
12+
func TestCleanup(t *testing.T) {
13+
t.Run("", func(t *testing.T) {
14+
t.Setenv("PGAPPNAME", "pqgo-cleanup")
15+
db := pqtest.MustDB(t)
16+
pqtest.Begin(t, db)
17+
pqtest.Query[int](t, db, `select 1`)
18+
stmt := pqtest.Prepare(t, db, `select 1`)
19+
20+
// No helper function for these as they're not used that frequently, and
21+
// for stmt.Query() also difficult to do right.
22+
rows1, _ := db.Query(`select 1`)
23+
defer rows1.Close()
24+
rows2, _ := stmt.Query()
25+
defer rows2.Close()
26+
})
27+
28+
rows := pqtest.Query[any](t, pqtest.MustDB(t),
29+
`select pid, query from pg_stat_activity where application_name = 'pqgo-cleanup' and pid != pg_backend_pid()`)
30+
for _, r := range rows {
31+
t.Errorf("connection still active: pid=%d; query=%q\n", r["pid"], r["query"])
32+
}
33+
}

notify_test.go

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"math/big"
1010
"net"
1111
"runtime"
12-
"sync"
12+
"sync/atomic"
1313
"testing"
1414
"time"
1515

@@ -181,24 +181,23 @@ func TestListenerConnExecDeadlock(t *testing.T) {
181181
l, _ := newTestListenerConn(t)
182182
defer l.Close()
183183

184-
var wg sync.WaitGroup
185-
wg.Add(2)
184+
var done atomic.Int32
186185
go func() {
187186
l.ExecSimpleQuery("select pg_sleep(0.2)")
188-
wg.Done()
187+
done.Add(1)
189188
}()
190189
runtime.Gosched()
191190
go func() {
192191
l.ExecSimpleQuery("select 1")
193-
wg.Done()
192+
done.Add(1)
194193
}()
195-
// give the two goroutines some time to get into position
196-
runtime.Gosched()
197-
// calls Close on the net.Conn; equivalent to a network failure
198-
l.Close()
194+
runtime.Gosched() // Give the above goroutine some time to get into position.
195+
l.Close() // Calls Close on the net.Conn; equivalent to a network failure.
199196

200-
defer time.AfterFunc(200*time.Millisecond, func() { panic("timed out") }).Stop()
201-
wg.Wait()
197+
time.Sleep(200 * time.Millisecond)
198+
if done.Load() != 2 {
199+
t.Fatal("timed out")
200+
}
202201
}
203202

204203
// Test for ListenerConn being closed while a slow query is executing
@@ -207,29 +206,26 @@ func TestListenerConnCloseWhileQueryIsExecuting(t *testing.T) {
207206
l, _ := newTestListenerConn(t)
208207
defer l.Close()
209208

210-
var wg sync.WaitGroup
211-
wg.Add(1)
212-
209+
var done atomic.Int32
213210
go func() {
214211
sent, err := l.ExecSimpleQuery("select pg_sleep(0.2)")
215212
if sent {
216213
panic("expected sent=false")
217214
}
218-
// could be any of a number of errors
219-
if err == nil {
215+
if err == nil { // Could be any of a number of errors.
220216
panic("expected error")
221217
}
222-
wg.Done()
218+
done.Add(1)
223219
}()
224-
// give the above goroutine some time to get into position
225-
runtime.Gosched()
226-
err := l.Close()
227-
if err != nil {
220+
runtime.Gosched() // Give the above goroutine some time to get into position.
221+
if err := l.Close(); err != nil {
228222
t.Fatal(err)
229223
}
230224

231-
defer time.AfterFunc(200*time.Millisecond, func() { panic("timed out") }).Stop()
232-
wg.Wait()
225+
time.Sleep(200 * time.Millisecond)
226+
if done.Load() != 1 {
227+
t.Fatal("timed out")
228+
}
233229
}
234230

235231
func TestListenerNotifyExtra(t *testing.T) {

0 commit comments

Comments
 (0)