Skip to content

Commit 3569a01

Browse files
committed
Added tests to test chunking of the BBOXES.
1 parent 25c278a commit 3569a01

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed

emission/individual_tests/TestOverpass.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import requests
1111
import emission.net.ext_service.transit_matching.match_stops as enetm
1212
import logging
13+
import attrdict as ad
14+
import math
1315

1416
#Set up query
1517
GEOFABRIK_OVERPASS_KEY = os.environ.get("GEOFABRIK_OVERPASS_KEY")
@@ -68,6 +70,153 @@ def test_get_predicted_transit_mode(self):
6870
expected_result = ['train', 'train']
6971
self.assertEqual(actual_result, expected_result)
7072

73+
def test_chunk_list(self):
74+
# Case 1: List of 10 elements with chunk size of 3.
75+
data = list(range(1, 11)) # [1, 2, ..., 10]
76+
chunk_size = 3
77+
chunks = list(enetm.chunk_list(data, chunk_size))
78+
expected_chunks = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
79+
self.assertEqual(chunks, expected_chunks)
80+
81+
# Case 2: Exact division
82+
data_exact = list(range(1, 10)) # [1, 2, ..., 9]
83+
chunk_size = 3
84+
chunks_exact = list(enetm.chunk_list(data_exact, chunk_size))
85+
expected_chunks_exact = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
86+
self.assertEqual(chunks_exact, expected_chunks_exact)
87+
88+
# Case 3: Empty list
89+
data_empty = []
90+
chunks_empty = list(enetm.chunk_list(data_empty, chunk_size))
91+
self.assertEqual(chunks_empty, [])
92+
93+
def test_get_stops_near_many_chunks(self):
94+
"""
95+
Test get_stops_near with many chunks.
96+
Override MAX_BBOXES_PER_QUERY to 1 so that each coordinate produces its own chunk.
97+
Supply 20 dummy coordinates and verify that 20 chunks are returned.
98+
"""
99+
original_max = enetm.MAX_BBOXES_PER_QUERY
100+
enetm.MAX_BBOXES_PER_QUERY = 1
101+
102+
# Create 20 dummy coordinates ([lon, lat]).
103+
coords = [[i, i + 0.5] for i in range(20)]
104+
105+
# Patch make_request_and_catch to return a dummy response.
106+
original_make_request_and_catch = enetm.make_request_and_catch
107+
def dummy_make_request_and_catch(query):
108+
# Return a dummy response: one dummy node and a "count" marker.
109+
return [{'id': 100, 'type': 'node', 'tags': {'dummy': True}},
110+
{'type': 'count'}]
111+
enetm.make_request_and_catch = dummy_make_request_and_catch
112+
113+
stops = enetm.get_stops_near(coords, 150.0)
114+
# Expect one chunk per coordinate = 20 chunks.
115+
self.assertEqual(len(stops), 20)
116+
for chunk in stops:
117+
# Each chunk (from the dummy response) should contain one stop.
118+
self.assertEqual(len(chunk), 1)
119+
self.assertEqual(chunk[0]['tags'], {'dummy': True})
120+
121+
# Restore original settings.
122+
enetm.MAX_BBOXES_PER_QUERY = original_max
123+
enetm.make_request_and_catch = original_make_request_and_catch
124+
125+
def test_get_predicted_transit_mode_many_chunks(self):
126+
"""
127+
Test get_predicted_transit_mode when provided with many stops.
128+
Simulate two sets (start and end) of 20 stops each, where each stop carries
129+
a unique route (with matching ids across both sets). Expect one matching route per stop.
130+
"""
131+
start_stops = []
132+
end_stops = []
133+
for i in range(20):
134+
# Create a dummy route with a unique id and route "train".
135+
# Include a "ref" key to avoid an AttributeError.
136+
route = ad.AttrDict({
137+
'id': i,
138+
'tags': {'route': 'train', 'ref': str(i)}
139+
})
140+
stop_start = ad.AttrDict({
141+
'id': i,
142+
'tags': {},
143+
'routes': [route]
144+
})
145+
stop_end = ad.AttrDict({
146+
'id': i,
147+
'tags': {},
148+
'routes': [route]
149+
})
150+
start_stops.append(stop_start)
151+
end_stops.append(stop_end)
152+
153+
actual_result = enetm.get_predicted_transit_mode(start_stops, end_stops)
154+
expected_result = ['train'] * 20
155+
self.assertEqual(actual_result, expected_result)
156+
157+
def test_get_stops_near_different_batch_sizes(self):
158+
"""
159+
Test get_stops_near using varying batch sizes.
160+
For each batch size, override MAX_BBOXES_PER_QUERY and supply a fixed list of dummy
161+
coordinates. Verify that the number of returned chunks equals ceil(total_coords / batch_size).
162+
"""
163+
original_max = enetm.MAX_BBOXES_PER_QUERY
164+
original_make_request_and_catch = enetm.make_request_and_catch
165+
166+
# Create 7 dummy coordinates.
167+
coords = [[i, i + 0.5] for i in range(7)]
168+
169+
# Dummy response: one dummy node and a "count" marker.
170+
def dummy_make_request_and_catch(query):
171+
return [{'id': 100, 'type': 'node', 'tags': {'dummy': True}},
172+
{'type': 'count'}]
173+
enetm.make_request_and_catch = dummy_make_request_and_catch
174+
175+
for batch_size in [1, 2, 5, 10]:
176+
enetm.MAX_BBOXES_PER_QUERY = batch_size
177+
stops = enetm.get_stops_near(coords, 150.0)
178+
expected_chunks = math.ceil(len(coords) / batch_size)
179+
self.assertEqual(len(stops), expected_chunks,
180+
msg=f"Batch size {batch_size} produced {len(stops)} chunks; expected {expected_chunks}.")
181+
for chunk in stops:
182+
self.assertEqual(len(chunk), 1)
183+
self.assertEqual(chunk[0]['tags'], {'dummy': True})
184+
185+
# Restore original settings.
186+
enetm.MAX_BBOXES_PER_QUERY = original_max
187+
enetm.make_request_and_catch = original_make_request_and_catch
188+
189+
def test_get_predicted_transit_mode_different_sizes(self):
190+
"""
191+
Test get_predicted_transit_mode for different numbers of stops.
192+
For various sizes, simulate matching start and end stops and verify the expected matching routes.
193+
"""
194+
for size in [1, 3, 7, 20]:
195+
start_stops = []
196+
end_stops = []
197+
for i in range(size):
198+
route = ad.AttrDict({
199+
'id': i,
200+
'tags': {'route': 'train', 'ref': str(i)}
201+
})
202+
stop_start = ad.AttrDict({
203+
'id': i,
204+
'tags': {},
205+
'routes': [route]
206+
})
207+
stop_end = ad.AttrDict({
208+
'id': i,
209+
'tags': {},
210+
'routes': [route]
211+
})
212+
start_stops.append(stop_start)
213+
end_stops.append(stop_end)
214+
actual_result = enetm.get_predicted_transit_mode(start_stops, end_stops)
215+
expected_result = ['train'] * size
216+
self.assertEqual(actual_result, expected_result,
217+
msg=f"For {size} stops, expected {expected_result} but got {actual_result}.")
218+
219+
71220
if __name__ == '__main__':
72221
unittest.main()
73222

0 commit comments

Comments
 (0)