Skip to content

Commit a62b5c5

Browse files
authored
Refine BacktestDataIterator (#2591)
1 parent df5f6b1 commit a62b5c5

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

nautilus_trader/backtest/engine.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ cdef class BacktestDataIterator:
116116

117117
cpdef void _reset_single_data(self)
118118
cpdef void add_data(self, str data_name, list data_list, bint append_data=*)
119+
cdef void _add_data(self, str data_name, list data_list, bint append_data=*)
119120
cpdef void remove_data(self, str data_name)
120121
cpdef void _activate_single_data(self)
121122
cpdef void _deactivate_single_data(self)

nautilus_trader/backtest/engine.pyx

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,10 @@ cdef class BacktestDataIterator:
16781678
self._is_single_data = False
16791679
16801680
cpdef void add_data(self, str data_name, list data_list, bint append_data=True):
1681+
# closures inside cpdef functions not yet supported
1682+
self._add_data(data_name, data_list, append_data)
1683+
1684+
cdef void _add_data(self, str data_name, list data_list, bint append_data=True):
16811685
if len(data_list) == 0:
16821686
return
16831687
@@ -1687,24 +1691,25 @@ cdef class BacktestDataIterator:
16871691
data_priority = self._data_priority[data_name]
16881692
self.remove_data(data_name)
16891693
else:
1690-
# heapq is a min priority data so lower number means higher priority
1694+
# heapq is a min priority queue so lower number means higher priority
16911695
data_priority = (1 if append_data else -1) * self._next_data_priority
16921696
self._next_data_priority += 1
16931697
16941698
if self._is_single_data:
16951699
self._deactivate_single_data()
16961700
1697-
self._data[data_priority] = data_list
1701+
self._data[data_priority] = sorted(data_list, key=lambda data: data.ts_init)
16981702
self._data_name[data_priority] = data_name
16991703
self._data_priority[data_name] = data_priority
17001704
self._data_len[data_priority] = len(data_list)
17011705
self._data_index[data_priority] = 0
1702-
self._push_data(data_priority, 0)
17031706
17041707
if len(self._data) == 1:
17051708
self._activate_single_data()
17061709
return
17071710
1711+
self._push_data(data_priority, 0)
1712+
17081713
cpdef void remove_data(self, str data_name):
17091714
if data_name not in self._data_priority:
17101715
return
@@ -1767,11 +1772,17 @@ cdef class BacktestDataIterator:
17671772
17681773
return object_to_return
17691774
1770-
cursor = self._single_data_index
1775+
if self._single_data_index >= self._single_data_len:
1776+
return None
1777+
1778+
object_to_return = self._single_data[self._single_data_index]
17711779
self._single_data_index += 1
17721780
1773-
if cursor < self._single_data_len:
1774-
return self._single_data[cursor]
1781+
if self._single_data_index >= self._single_data_len:
1782+
if self._empty_data_callback is not None:
1783+
self._empty_data_callback(self._single_data_name, self._single_data[-1].ts_init)
1784+
1785+
return object_to_return
17751786
17761787
cpdef void _push_data(self, int data_priority, int data_index):
17771788
cdef uint64_t ts_init
@@ -1812,11 +1823,14 @@ cdef class BacktestDataIterator:
18121823
self._reset_heap()
18131824
18141825
cpdef bint is_done(self):
1815-
return (self._is_single_data and self._single_data_index >= self._single_data_len) or not self._heap
1826+
if self._is_single_data:
1827+
return self._single_data_index >= self._single_data_len
1828+
else:
1829+
return not self._heap
18161830
18171831
cpdef dict all_data(self):
18181832
# we assume dicts are ordered by order of insertion
1819-
return {data_name:data for data_name, data in zip(self._data_name.values(), self._data.values())}
1833+
return {data_name:self._data[data_priority] for data_priority, data_name in self._data_name.items()}
18201834
18211835
cpdef list data(self, str data_name):
18221836
return self._data[self._data_priority[data_name]]

tests/unit_tests/backtest/test_data_iterator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_all_data_returns_mapping(self):
7373

7474
# Assert
7575
assert list(mapping.keys()) == ["only"]
76-
assert mapping["only"] is lst
76+
assert mapping["only"] == lst
7777

7878
def test_remove_stream_effect(self):
7979
"""
@@ -164,11 +164,12 @@ def cb(name, ts):
164164

165165
# Reset and re-consume
166166
it.reset()
167+
callback_data = []
167168
values = [x.value for x in it]
168169

169170
# Assert
170171
assert (first, second, third) == (1, 2, 3)
171-
assert callback_data == []
172+
assert callback_data == [("single", 3)]
172173
assert it.is_done()
173174

174175
assert values == [1, 2, 3]
@@ -216,7 +217,7 @@ def test_set_index_and_data_accessor_and_is_done_empty(self):
216217
data = [MyData(10, ts_init=10), MyData(20, ts_init=20), MyData(30, ts_init=30)]
217218
it.add_data("stream", data)
218219

219-
assert it.data("stream") is data
220+
assert it.data("stream") == data
220221

221222
with pytest.raises(KeyError):
222223
it.data("unknown")

0 commit comments

Comments
 (0)