|
10 | 10 | import requests
|
11 | 11 | import emission.net.ext_service.transit_matching.match_stops as enetm
|
12 | 12 | import logging
|
| 13 | +import attrdict as ad |
| 14 | +import math |
13 | 15 |
|
14 | 16 | #Set up query
|
15 | 17 | GEOFABRIK_OVERPASS_KEY = os.environ.get("GEOFABRIK_OVERPASS_KEY")
|
@@ -68,6 +70,153 @@ def test_get_predicted_transit_mode(self):
|
68 | 70 | expected_result = ['train', 'train']
|
69 | 71 | self.assertEqual(actual_result, expected_result)
|
70 | 72 |
|
| 73 | + def test_chunk_list(self): |
| 74 | + # Case 1: List of 10 elements with chunk size of 3. |
| 75 | + data = list(range(1, 11)) # [1, 2, ..., 10] |
| 76 | + chunk_size = 3 |
| 77 | + chunks = list(enetm.chunk_list(data, chunk_size)) |
| 78 | + expected_chunks = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]] |
| 79 | + self.assertEqual(chunks, expected_chunks) |
| 80 | + |
| 81 | + # Case 2: Exact division |
| 82 | + data_exact = list(range(1, 10)) # [1, 2, ..., 9] |
| 83 | + chunk_size = 3 |
| 84 | + chunks_exact = list(enetm.chunk_list(data_exact, chunk_size)) |
| 85 | + expected_chunks_exact = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] |
| 86 | + self.assertEqual(chunks_exact, expected_chunks_exact) |
| 87 | + |
| 88 | + # Case 3: Empty list |
| 89 | + data_empty = [] |
| 90 | + chunks_empty = list(enetm.chunk_list(data_empty, chunk_size)) |
| 91 | + self.assertEqual(chunks_empty, []) |
| 92 | + |
| 93 | + def test_get_stops_near_many_chunks(self): |
| 94 | + """ |
| 95 | + Test get_stops_near with many chunks. |
| 96 | + Override MAX_BBOXES_PER_QUERY to 1 so that each coordinate produces its own chunk. |
| 97 | + Supply 20 dummy coordinates and verify that 20 chunks are returned. |
| 98 | + """ |
| 99 | + original_max = enetm.MAX_BBOXES_PER_QUERY |
| 100 | + enetm.MAX_BBOXES_PER_QUERY = 1 |
| 101 | + |
| 102 | + # Create 20 dummy coordinates ([lon, lat]). |
| 103 | + coords = [[i, i + 0.5] for i in range(20)] |
| 104 | + |
| 105 | + # Patch make_request_and_catch to return a dummy response. |
| 106 | + original_make_request_and_catch = enetm.make_request_and_catch |
| 107 | + def dummy_make_request_and_catch(query): |
| 108 | + # Return a dummy response: one dummy node and a "count" marker. |
| 109 | + return [{'id': 100, 'type': 'node', 'tags': {'dummy': True}}, |
| 110 | + {'type': 'count'}] |
| 111 | + enetm.make_request_and_catch = dummy_make_request_and_catch |
| 112 | + |
| 113 | + stops = enetm.get_stops_near(coords, 150.0) |
| 114 | + # Expect one chunk per coordinate = 20 chunks. |
| 115 | + self.assertEqual(len(stops), 20) |
| 116 | + for chunk in stops: |
| 117 | + # Each chunk (from the dummy response) should contain one stop. |
| 118 | + self.assertEqual(len(chunk), 1) |
| 119 | + self.assertEqual(chunk[0]['tags'], {'dummy': True}) |
| 120 | + |
| 121 | + # Restore original settings. |
| 122 | + enetm.MAX_BBOXES_PER_QUERY = original_max |
| 123 | + enetm.make_request_and_catch = original_make_request_and_catch |
| 124 | + |
| 125 | + def test_get_predicted_transit_mode_many_chunks(self): |
| 126 | + """ |
| 127 | + Test get_predicted_transit_mode when provided with many stops. |
| 128 | + Simulate two sets (start and end) of 20 stops each, where each stop carries |
| 129 | + a unique route (with matching ids across both sets). Expect one matching route per stop. |
| 130 | + """ |
| 131 | + start_stops = [] |
| 132 | + end_stops = [] |
| 133 | + for i in range(20): |
| 134 | + # Create a dummy route with a unique id and route "train". |
| 135 | + # Include a "ref" key to avoid an AttributeError. |
| 136 | + route = ad.AttrDict({ |
| 137 | + 'id': i, |
| 138 | + 'tags': {'route': 'train', 'ref': str(i)} |
| 139 | + }) |
| 140 | + stop_start = ad.AttrDict({ |
| 141 | + 'id': i, |
| 142 | + 'tags': {}, |
| 143 | + 'routes': [route] |
| 144 | + }) |
| 145 | + stop_end = ad.AttrDict({ |
| 146 | + 'id': i, |
| 147 | + 'tags': {}, |
| 148 | + 'routes': [route] |
| 149 | + }) |
| 150 | + start_stops.append(stop_start) |
| 151 | + end_stops.append(stop_end) |
| 152 | + |
| 153 | + actual_result = enetm.get_predicted_transit_mode(start_stops, end_stops) |
| 154 | + expected_result = ['train'] * 20 |
| 155 | + self.assertEqual(actual_result, expected_result) |
| 156 | + |
| 157 | + def test_get_stops_near_different_batch_sizes(self): |
| 158 | + """ |
| 159 | + Test get_stops_near using varying batch sizes. |
| 160 | + For each batch size, override MAX_BBOXES_PER_QUERY and supply a fixed list of dummy |
| 161 | + coordinates. Verify that the number of returned chunks equals ceil(total_coords / batch_size). |
| 162 | + """ |
| 163 | + original_max = enetm.MAX_BBOXES_PER_QUERY |
| 164 | + original_make_request_and_catch = enetm.make_request_and_catch |
| 165 | + |
| 166 | + # Create 7 dummy coordinates. |
| 167 | + coords = [[i, i + 0.5] for i in range(7)] |
| 168 | + |
| 169 | + # Dummy response: one dummy node and a "count" marker. |
| 170 | + def dummy_make_request_and_catch(query): |
| 171 | + return [{'id': 100, 'type': 'node', 'tags': {'dummy': True}}, |
| 172 | + {'type': 'count'}] |
| 173 | + enetm.make_request_and_catch = dummy_make_request_and_catch |
| 174 | + |
| 175 | + for batch_size in [1, 2, 5, 10]: |
| 176 | + enetm.MAX_BBOXES_PER_QUERY = batch_size |
| 177 | + stops = enetm.get_stops_near(coords, 150.0) |
| 178 | + expected_chunks = math.ceil(len(coords) / batch_size) |
| 179 | + self.assertEqual(len(stops), expected_chunks, |
| 180 | + msg=f"Batch size {batch_size} produced {len(stops)} chunks; expected {expected_chunks}.") |
| 181 | + for chunk in stops: |
| 182 | + self.assertEqual(len(chunk), 1) |
| 183 | + self.assertEqual(chunk[0]['tags'], {'dummy': True}) |
| 184 | + |
| 185 | + # Restore original settings. |
| 186 | + enetm.MAX_BBOXES_PER_QUERY = original_max |
| 187 | + enetm.make_request_and_catch = original_make_request_and_catch |
| 188 | + |
| 189 | + def test_get_predicted_transit_mode_different_sizes(self): |
| 190 | + """ |
| 191 | + Test get_predicted_transit_mode for different numbers of stops. |
| 192 | + For various sizes, simulate matching start and end stops and verify the expected matching routes. |
| 193 | + """ |
| 194 | + for size in [1, 3, 7, 20]: |
| 195 | + start_stops = [] |
| 196 | + end_stops = [] |
| 197 | + for i in range(size): |
| 198 | + route = ad.AttrDict({ |
| 199 | + 'id': i, |
| 200 | + 'tags': {'route': 'train', 'ref': str(i)} |
| 201 | + }) |
| 202 | + stop_start = ad.AttrDict({ |
| 203 | + 'id': i, |
| 204 | + 'tags': {}, |
| 205 | + 'routes': [route] |
| 206 | + }) |
| 207 | + stop_end = ad.AttrDict({ |
| 208 | + 'id': i, |
| 209 | + 'tags': {}, |
| 210 | + 'routes': [route] |
| 211 | + }) |
| 212 | + start_stops.append(stop_start) |
| 213 | + end_stops.append(stop_end) |
| 214 | + actual_result = enetm.get_predicted_transit_mode(start_stops, end_stops) |
| 215 | + expected_result = ['train'] * size |
| 216 | + self.assertEqual(actual_result, expected_result, |
| 217 | + msg=f"For {size} stops, expected {expected_result} but got {actual_result}.") |
| 218 | + |
| 219 | + |
71 | 220 | if __name__ == '__main__':
|
72 | 221 | unittest.main()
|
73 | 222 |
|
0 commit comments