Skip to content
This repository was archived by the owner on Feb 1, 2024. It is now read-only.

Commit ad60765

Browse files
committed
Database Schema Test Infrastructure also tests indexes on tables
1 parent fea5021 commit ad60765

3 files changed

Lines changed: 52 additions & 1 deletion

File tree

cmd/trade_test.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ func TestTradeUpgradeScripts(t *testing.T) {
6969
DataType: "text",
7070
CharacterMaximumLength: nil,
7171
}, &columns[4])
72+
// check indexes of db_version table
73+
indexes := database.GetTableIndexes(db, "db_version")
74+
assert.Equal(t, 1, len(indexes))
75+
database.AssertIndex(t, "db_version", "db_version_pkey", "CREATE UNIQUE INDEX db_version_pkey ON public.db_version USING btree (version)", indexes)
7276

7377
// check schema of markets table
7478
columns = database.GetTableSchema(db, "markets")
@@ -105,6 +109,10 @@ func TestTradeUpgradeScripts(t *testing.T) {
105109
DataType: "text",
106110
CharacterMaximumLength: nil,
107111
}, &columns[3])
112+
// check indexes of markets table
113+
indexes = database.GetTableIndexes(db, "markets")
114+
assert.Equal(t, 1, len(indexes))
115+
database.AssertIndex(t, "markets", "markets_pkey", "CREATE UNIQUE INDEX markets_pkey ON public.markets USING btree (market_id)", indexes)
108116

109117
// check schema of trades table
110118
columns = database.GetTableSchema(db, "trades")
@@ -181,6 +189,11 @@ func TestTradeUpgradeScripts(t *testing.T) {
181189
DataType: "double precision",
182190
CharacterMaximumLength: nil,
183191
}, &columns[8])
192+
// check indexes of trades table
193+
indexes = database.GetTableIndexes(db, "trades")
194+
assert.Equal(t, 2, len(indexes))
195+
database.AssertIndex(t, "trades", "trades_pkey", "CREATE UNIQUE INDEX trades_pkey ON public.trades USING btree (market_id, txid)", indexes)
196+
database.AssertIndex(t, "trades", "trades_mdd", "CREATE INDEX trades_mdd ON public.trades USING btree (market_id, date(date_utc), date_utc)", indexes)
184197

185198
// check entries of db_version table
186199
var allRows [][]interface{}
@@ -189,7 +202,7 @@ func TestTradeUpgradeScripts(t *testing.T) {
189202
// first three code_version_string is nil becuase the field was not supported at the time when the upgrade script was run, and only in version 4 of
190203
// the database do we add the field. See upgradeScripts and RunUpgradeScripts() for more details
191204
database.ValidateDBVersionRow(t, allRows[0], 1, time.Now(), 1, 10, nil)
192-
database.ValidateDBVersionRow(t, allRows[1], 2, time.Now(), 3, 10, nil)
205+
database.ValidateDBVersionRow(t, allRows[1], 2, time.Now(), 3, 15, nil)
193206
database.ValidateDBVersionRow(t, allRows[2], 3, time.Now(), 2, 10, nil)
194207
database.ValidateDBVersionRow(t, allRows[3], 4, time.Now(), 1, 10, &codeVersionString)
195208

support/database/upgrade_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ func TestUpgradeScripts(t *testing.T) {
8181
CharacterMaximumLength: nil,
8282
}, &columns[4])
8383

84+
// check indexes of db_version table
85+
indexes := GetTableIndexes(db, "db_version")
86+
assert.Equal(t, 1, len(indexes))
87+
AssertIndex(t, "db_version", "db_version_pkey", "CREATE UNIQUE INDEX db_version_pkey ON public.db_version USING btree (version)", indexes)
88+
8489
// check entries of db_version table
8590
allRows := QueryAllRows(db, "db_version")
8691
assert.Equal(t, 2, len(allRows))

support/database/upgrade_test_helper.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,39 @@ func GetTableSchema(db *sql.DB, tableName string) []TableColumn {
133133
return items
134134
}
135135

136+
// IndexSearchResult captures the result from GetTableIndexes() and is used as input to AssertIndex()
137+
type IndexSearchResult map[string]string
138+
139+
// GetTableIndexes is well-named
140+
func GetTableIndexes(db *sql.DB, tableName string) IndexSearchResult {
141+
indexQueryResult, e := db.Query(fmt.Sprintf("SELECT indexname, indexdef from pg_indexes where schemaname = 'public' AND tablename = '%s'", tableName))
142+
if e != nil {
143+
panic(e)
144+
}
145+
defer indexQueryResult.Close() // remembering to defer closing the query
146+
147+
m := map[string]string{}
148+
for indexQueryResult.Next() { // remembering to call Next() before Scan()
149+
var name, def string
150+
e = indexQueryResult.Scan(&name, &def)
151+
if e != nil {
152+
panic(e)
153+
}
154+
155+
m[name] = def
156+
}
157+
158+
return m
159+
}
160+
161+
// AssertIndex validates that the index exists
162+
func AssertIndex(t *testing.T, tableName string, wantIndexName string, wantDefinition string, indexes IndexSearchResult) {
163+
m := map[string]string(indexes)
164+
if v, ok := m[wantIndexName]; assert.True(t, ok, fmt.Sprintf("index '%s' should exist in the table '%s'", wantIndexName, tableName)) {
165+
assert.Equal(t, wantDefinition, v)
166+
}
167+
}
168+
136169
// QueryAllRows queries all the rows of a given table in a database
137170
func QueryAllRows(db *sql.DB, tableName string) [][]interface{} {
138171
queryResult, e := db.Query(fmt.Sprintf("SELECT * FROM %s", tableName))

0 commit comments

Comments
 (0)