Skip to content
This repository was archived by the owner on Feb 1, 2024. It is now read-only.

Commit fea5021

Browse files
authored
Add support for code_version_string field in the db_version table of the trade bot's database (closes #447) (#448)
* 1 - export test infrastructure fields for use outside the support/database package * 2 - TestTradeUpgradeScripts along with upgrading db_version table in trade#upgradeScripts variable
1 parent f394b46 commit fea5021

5 files changed

Lines changed: 487 additions & 258 deletions

File tree

cmd/trade.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ import (
3232
)
3333

3434
var upgradeScripts = []*database.UpgradeScript{
35-
database.MakeUpgradeScript(1, database.SqlDbVersionTableCreate),
35+
database.MakeUpgradeScript(1,
36+
database.SqlDbVersionTableCreate,
37+
),
3638
database.MakeUpgradeScript(2,
3739
kelpdb.SqlMarketsTableCreate,
3840
kelpdb.SqlTradesTableCreate,
@@ -42,6 +44,9 @@ var upgradeScripts = []*database.UpgradeScript{
4244
kelpdb.SqlTradesIndexDrop,
4345
kelpdb.SqlTradesIndexCreate2,
4446
),
47+
database.MakeUpgradeScript(4,
48+
database.SqlDbVersionTableAlter1,
49+
),
4550
}
4651

4752
const tradeExamples = ` kelp trade --botConf ./path/trader.cfg --strategy buysell --stratConf ./path/buysell.cfg

cmd/trade_test.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package cmd
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/assert"
9+
10+
"github.com/stellar/kelp/support/database"
11+
)
12+
13+
func TestTradeUpgradeScripts(t *testing.T) {
14+
// run the PreTest and defer running the postTest
15+
db, dbname := database.PreTest(t)
16+
defer database.PostTestWithDbClose(db, dbname)
17+
18+
// run the upgrade scripts
19+
codeVersionString := "TestTradeUpgradeScripts"
20+
database.RunUpgradeScripts(db, upgradeScripts, codeVersionString)
21+
22+
// assert current state of the database
23+
assert.Equal(t, 3, database.GetNumTablesInDb(db))
24+
assert.True(t, database.CheckTableExists(db, "db_version"))
25+
assert.True(t, database.CheckTableExists(db, "markets"))
26+
assert.True(t, database.CheckTableExists(db, "trades"))
27+
28+
// check schema of db_version table
29+
var columns []database.TableColumn
30+
columns = database.GetTableSchema(db, "db_version")
31+
assert.Equal(t, 5, len(columns), fmt.Sprintf("%v", columns))
32+
database.AssertTableColumnsEqual(t, &database.TableColumn{
33+
ColumnName: "version",
34+
OrdinalPosition: 1,
35+
ColumnDefault: nil,
36+
IsNullable: "NO",
37+
DataType: "integer",
38+
CharacterMaximumLength: nil,
39+
}, &columns[0])
40+
database.AssertTableColumnsEqual(t, &database.TableColumn{
41+
ColumnName: "date_completed_utc",
42+
OrdinalPosition: 2,
43+
ColumnDefault: nil,
44+
IsNullable: "NO",
45+
DataType: "timestamp without time zone",
46+
CharacterMaximumLength: nil,
47+
}, &columns[1])
48+
database.AssertTableColumnsEqual(t, &database.TableColumn{
49+
ColumnName: "num_scripts",
50+
OrdinalPosition: 3,
51+
ColumnDefault: nil,
52+
IsNullable: "NO",
53+
DataType: "integer",
54+
CharacterMaximumLength: nil,
55+
}, &columns[2])
56+
database.AssertTableColumnsEqual(t, &database.TableColumn{
57+
ColumnName: "time_elapsed_millis",
58+
OrdinalPosition: 4,
59+
ColumnDefault: nil,
60+
IsNullable: "NO",
61+
DataType: "bigint",
62+
CharacterMaximumLength: nil,
63+
}, &columns[3])
64+
database.AssertTableColumnsEqual(t, &database.TableColumn{
65+
ColumnName: "code_version_string",
66+
OrdinalPosition: 5,
67+
ColumnDefault: nil,
68+
IsNullable: "YES",
69+
DataType: "text",
70+
CharacterMaximumLength: nil,
71+
}, &columns[4])
72+
73+
// check schema of markets table
74+
columns = database.GetTableSchema(db, "markets")
75+
assert.Equal(t, 4, len(columns), fmt.Sprintf("%v", columns))
76+
database.AssertTableColumnsEqual(t, &database.TableColumn{
77+
ColumnName: "market_id",
78+
OrdinalPosition: 1,
79+
ColumnDefault: nil,
80+
IsNullable: "NO",
81+
DataType: "text",
82+
CharacterMaximumLength: nil,
83+
}, &columns[0])
84+
database.AssertTableColumnsEqual(t, &database.TableColumn{
85+
ColumnName: "exchange_name",
86+
OrdinalPosition: 2,
87+
ColumnDefault: nil,
88+
IsNullable: "NO",
89+
DataType: "text",
90+
CharacterMaximumLength: nil,
91+
}, &columns[1])
92+
database.AssertTableColumnsEqual(t, &database.TableColumn{
93+
ColumnName: "base",
94+
OrdinalPosition: 3,
95+
ColumnDefault: nil,
96+
IsNullable: "NO",
97+
DataType: "text",
98+
CharacterMaximumLength: nil,
99+
}, &columns[2])
100+
database.AssertTableColumnsEqual(t, &database.TableColumn{
101+
ColumnName: "quote",
102+
OrdinalPosition: 4,
103+
ColumnDefault: nil,
104+
IsNullable: "NO",
105+
DataType: "text",
106+
CharacterMaximumLength: nil,
107+
}, &columns[3])
108+
109+
// check schema of trades table
110+
columns = database.GetTableSchema(db, "trades")
111+
assert.Equal(t, 9, len(columns), fmt.Sprintf("%v", columns))
112+
database.AssertTableColumnsEqual(t, &database.TableColumn{
113+
ColumnName: "market_id",
114+
OrdinalPosition: 1,
115+
ColumnDefault: nil,
116+
IsNullable: "NO",
117+
DataType: "text",
118+
CharacterMaximumLength: nil,
119+
}, &columns[0])
120+
database.AssertTableColumnsEqual(t, &database.TableColumn{
121+
ColumnName: "txid",
122+
OrdinalPosition: 2,
123+
ColumnDefault: nil,
124+
IsNullable: "NO",
125+
DataType: "text",
126+
CharacterMaximumLength: nil,
127+
}, &columns[1])
128+
database.AssertTableColumnsEqual(t, &database.TableColumn{
129+
ColumnName: "date_utc",
130+
OrdinalPosition: 3,
131+
ColumnDefault: nil,
132+
IsNullable: "NO",
133+
DataType: "timestamp without time zone",
134+
CharacterMaximumLength: nil,
135+
}, &columns[2])
136+
database.AssertTableColumnsEqual(t, &database.TableColumn{
137+
ColumnName: "action",
138+
OrdinalPosition: 4,
139+
ColumnDefault: nil,
140+
IsNullable: "NO",
141+
DataType: "text",
142+
CharacterMaximumLength: nil,
143+
}, &columns[3])
144+
database.AssertTableColumnsEqual(t, &database.TableColumn{
145+
ColumnName: "type",
146+
OrdinalPosition: 5,
147+
ColumnDefault: nil,
148+
IsNullable: "NO",
149+
DataType: "text",
150+
CharacterMaximumLength: nil,
151+
}, &columns[4])
152+
database.AssertTableColumnsEqual(t, &database.TableColumn{
153+
ColumnName: "counter_price",
154+
OrdinalPosition: 6,
155+
ColumnDefault: nil,
156+
IsNullable: "NO",
157+
DataType: "double precision",
158+
CharacterMaximumLength: nil,
159+
}, &columns[5])
160+
database.AssertTableColumnsEqual(t, &database.TableColumn{
161+
ColumnName: "base_volume",
162+
OrdinalPosition: 7,
163+
ColumnDefault: nil,
164+
IsNullable: "NO",
165+
DataType: "double precision",
166+
CharacterMaximumLength: nil,
167+
}, &columns[6])
168+
database.AssertTableColumnsEqual(t, &database.TableColumn{
169+
ColumnName: "counter_cost",
170+
OrdinalPosition: 8,
171+
ColumnDefault: nil,
172+
IsNullable: "NO",
173+
DataType: "double precision",
174+
CharacterMaximumLength: nil,
175+
}, &columns[7])
176+
database.AssertTableColumnsEqual(t, &database.TableColumn{
177+
ColumnName: "fee",
178+
OrdinalPosition: 9,
179+
ColumnDefault: nil,
180+
IsNullable: "NO",
181+
DataType: "double precision",
182+
CharacterMaximumLength: nil,
183+
}, &columns[8])
184+
185+
// check entries of db_version table
186+
var allRows [][]interface{}
187+
allRows = database.QueryAllRows(db, "db_version")
188+
assert.Equal(t, 4, len(allRows))
189+
// first three code_version_string is nil becuase the field was not supported at the time when the upgrade script was run, and only in version 4 of
190+
// the database do we add the field. See upgradeScripts and RunUpgradeScripts() for more details
191+
database.ValidateDBVersionRow(t, allRows[0], 1, time.Now(), 1, 10, nil)
192+
database.ValidateDBVersionRow(t, allRows[1], 2, time.Now(), 3, 10, nil)
193+
database.ValidateDBVersionRow(t, allRows[2], 3, time.Now(), 2, 10, nil)
194+
database.ValidateDBVersionRow(t, allRows[3], 4, time.Now(), 1, 10, &codeVersionString)
195+
196+
// check entries of markets table
197+
allRows = database.QueryAllRows(db, "markets")
198+
assert.Equal(t, 0, len(allRows))
199+
200+
// check entries of markets table
201+
allRows = database.QueryAllRows(db, "trades")
202+
assert.Equal(t, 0, len(allRows))
203+
}

support/database/upgrade.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func ConnectInitializedDatabase(postgresDbConfig *postgresdb.Config, upgradeScri
5959
// don't defer db.Close() here becuase we want it open for the life of the application for now
6060

6161
log.Printf("creating db schema and running upgrade scripts ...\n")
62-
e = runUpgradeScripts(db, upgradeScripts, codeVersionString)
62+
e = RunUpgradeScripts(db, upgradeScripts, codeVersionString)
6363
if e != nil {
6464
return nil, fmt.Errorf("could not run upgrade scripts: %s", e)
6565
}
@@ -68,7 +68,8 @@ func ConnectInitializedDatabase(postgresDbConfig *postgresdb.Config, upgradeScri
6868
return db, nil
6969
}
7070

71-
func runUpgradeScripts(db *sql.DB, scripts []*UpgradeScript, codeVersionString string) error {
71+
// RunUpgradeScripts is a utility function that can be run from outside this package so we need to export it
72+
func RunUpgradeScripts(db *sql.DB, scripts []*UpgradeScript, codeVersionString string) error {
7273
// save feature flags for the db_version table here
7374
hasCodeVersionString := false
7475

0 commit comments

Comments
 (0)