12
12
import emission .storage .decorations .analysis_timeseries_queries as esda
13
13
import emission .core .get_database as edb
14
14
import emission .analysis .modelling .trip_model .run_model as eamtr
15
+ import emission .tests .common as etc
15
16
16
17
class TestForestModelLoadandSave (unittest .TestCase ):
17
18
"""
18
19
Tests to make sure the model load and save properly
19
20
"""
20
-
21
21
def setUp (self ):
22
22
"""
23
23
sets up the end-to-end run model test with Confirmedtrip data
24
24
"""
25
- logging .basicConfig (format = '%(asctime)s:%(levelname)s:%(message)s' ,
26
- level = logging .DEBUG )
27
-
28
25
# configuration for randomly-generated test data
29
26
self .user_id = user_id = 'TestForestModelLoadAndSave-TestData'
30
27
self .origin = (- 105.1705977 , 39.7402654 ,)
@@ -38,278 +35,162 @@ def setUp(self):
38
35
39
36
self .unused_user_id = 'asdjfkl;asdfjkl;asd08234ur13fi4jhf2103mkl'
40
37
41
- # Ensuring that no previous test data was left in DB after teardown,
42
38
ts = esta .TimeSeries .get_time_series (user_id )
43
- test_data = list (ts .find_entries (["analysis/confirmed_trip" ]))
44
- if len (test_data ) == 0 :
45
- # generate test data for the database
46
- logging .debug (f"inserting mock Confirmedtrips into database" )
47
-
48
- # generate labels with a known sample weight that we can rely on in the test
49
- label_data = {
50
- "mode_confirm" : ['ebike' , 'bike' ],
51
- "purpose_confirm" : ['happy-hour' , 'dog-park' ],
52
- "replaced_mode" : ['walk' ],
53
- "mode_weights" : [0.9 , 0.1 ],
54
- "purpose_weights" : [0.1 , 0.9 ]
55
- }
56
-
57
- test_data = etmm .generate_mock_trips (
58
- user_id = user_id ,
59
- trips = self .total_trips ,
60
- origin = self .origin ,
61
- destination = self .destination ,
62
- trip_part = 'od' ,
63
- label_data = label_data ,
64
- within_threshold = self .clustered_trips ,
65
- threshold = 0.004 , # ~400m
66
- has_label_p = self .has_label_percent
67
- )
68
-
69
- for result_entry in test_data :
70
- result_entry ['data' ]['start_local_dt' ]= result_entry ['metadata' ]['write_local_dt' ]
71
- result_entry ['data' ]['end_local_dt' ]= result_entry ['metadata' ]['write_local_dt' ]
72
39
73
- ts .bulk_insert (test_data )
40
+ # generate labels with a known sample weight that we can rely on in the test
41
+ label_data = {
42
+ "mode_confirm" : ['ebike' , 'bike' ],
43
+ "purpose_confirm" : ['happy-hour' , 'dog-park' ],
44
+ "replaced_mode" : ['walk' ],
45
+ "mode_weights" : [0.9 , 0.1 ],
46
+ "purpose_weights" : [0.1 , 0.9 ]
47
+ }
48
+
49
+ # generate test data for the database
50
+ test_data = etmm .generate_mock_trips (
51
+ user_id = user_id ,
52
+ trips = self .total_trips ,
53
+ origin = self .origin ,
54
+ destination = self .destination ,
55
+ trip_part = 'od' ,
56
+ label_data = label_data ,
57
+ within_threshold = self .clustered_trips ,
58
+ threshold = 0.004 , # ~400m
59
+ has_label_p = self .has_label_percent
60
+ )
74
61
75
- # confirm data write did not fail
76
- test_data = esda .get_entries (key = "analysis/confirmed_trip" , user_id = user_id , time_query = None )
77
- if len (test_data ) != self .total_trips :
78
- logging .debug (f'test invariant failed after generating test data' )
79
- self .fail ()
80
- else :
81
- logging .debug (f'found { self .total_trips } trips in database' )
62
+ for result_entry in test_data :
63
+ result_entry ['data' ]['start_local_dt' ]= result_entry ['metadata' ]['write_local_dt' ]
64
+ result_entry ['data' ]['end_local_dt' ]= result_entry ['metadata' ]['write_local_dt' ]
82
65
83
- self .forest_model_config = eamtc .get_config_value_or_raise ('model_parameters.forest' )
84
-
85
- def tearDown (self ):
86
- """
87
- clean up database
88
- """
89
- edb .get_analysis_timeseries_db ().delete_many ({'user_id' : self .user_id })
90
- edb .get_model_db ().delete_many ({'user_id' : self .user_id })
91
- edb .get_pipeline_state_db ().delete_many ({'user_id' : self .user_id })
66
+ ts .bulk_insert (test_data )
92
67
93
- def testForestModelRoundTrip (self ):
94
- """
95
- RoundTripTest : Serialising an object with 'to_dict' and then immediately
96
- deserialize it with 'from_dict'. After deserialization, the object should have
97
- the same state as original
98
- """
68
+ self .forest_model_config = eamtc .get_config_value_or_raise ('model_parameters.forest' )
99
69
100
- # logging.debug(f'creating Random Forest model based on trips in database')
70
+ # Build and train model
71
+ logging .debug (f'(TRAIN) creating a model based on trips in database' )
101
72
eamur .update_trip_model (
102
73
user_id = self .user_id ,
103
74
model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
104
75
model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
105
76
min_trips = self .min_trips ,
106
77
model_config = self .forest_model_config
107
78
)
108
-
109
- model = eamur ._load_stored_trip_model (
79
+
80
+ self . model = eamur ._load_stored_trip_model (
110
81
user_id = self .user_id ,
111
82
model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
112
83
model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
113
84
model_config = self .forest_model_config
114
- )
85
+ )
86
+
87
+ def tearDown (self ):
88
+ etc .dropAllCollections (edb ._get_current_db ())
115
89
116
- # logging.debug(f'Loading test data')
117
- test = esda .get_entries (key = "analysis/confirmed_trip" , user_id = self .user_id , time_query = None )
90
+ def testForestModelPredictionsEquality (self ):
91
+ """
92
+ EqualityTest : Serialising an object with 'to_dict' and then immediately
93
+ deserialize it with 'from_dict'. After deserialization, the object should have
94
+ the same state as original
118
95
119
- # logging.debug(f'Predictions on trips in database')
96
+ TypePreservationTest: To ensure that the serialization and deserialization
97
+ process maintains the data types of all model attributes.
98
+ The type of deserialized model attributes and the predictions of this must match
99
+ those of initial model.
100
+ """
101
+ test_trip_data = esda .get_entries (key = esda .CONFIRMED_TRIP_KEY , user_id = self .user_id , time_query = None )
120
102
121
103
predictions_list = eamur .predict_labels_with_n (
122
- trip_list = test ,
123
- model = model
104
+ trip_list = test_trip_data ,
105
+ model = self . model
124
106
)
125
107
126
- # logging.debug(f'Serialising the model ')
127
-
128
- model_data = model .to_dict ()
129
-
130
- # logging.debug(f'Deserialising the model')
131
-
132
-
108
+ model_data = self .model .to_dict ()
133
109
deserialized_model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER
134
110
deserialized_model = deserialized_model_type .build (self .forest_model_config )
135
111
deserialized_model .from_dict (model_data )
136
112
137
- # logging.debug(f'Predictions on trips using deserialised model')
138
- predictions_loaded_model_list = eamur .predict_labels_with_n (
139
- trip_list = test ,
140
- model = deserialized_model
113
+ predictions_deserialized_model_list = eamur .predict_labels_with_n (
114
+ trip_list = test_trip_data ,
115
+ model = deserialized_model
141
116
)
142
- # logging.debug(f'Assert that both predictions are the same')
143
- self .assertEqual (predictions_list , predictions_loaded_model_list , " should be equal" )
117
+
118
+ # Test if the types are correct
119
+ for attr in ['purpose_predictor' ,'mode_predictor' ,'replaced_predictor' ,'purpose_enc' ,'mode_enc' ,'train_df' ]:
120
+ deserialized_attr_value = getattr (deserialized_model .model ,attr )
121
+ original_attr_value = getattr (self .model .model ,attr )
122
+ # Check type preservation
123
+ self .assertIsInstance (deserialized_attr_value ,type (original_attr_value ), f"Type mismatch for { attr } " )
124
+
125
+ # Test if the values are the same
126
+ self .assertEqual (predictions_list , predictions_deserialized_model_list , " should be equal" )
144
127
145
128
def testForestModelConsistency (self ):
146
129
"""
147
130
ConsistencyTest : To Verify that the serialization and deserialization process
148
131
is consistent across multiple executions
149
132
"""
150
- # logging.debug(f'creating a model based on trips in database')
151
- eamur .update_trip_model (
152
- user_id = self .user_id ,
153
- model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
154
- model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
155
- min_trips = self .min_trips ,
156
- model_config = self .forest_model_config
157
- )
158
-
159
- model_iter1 = eamur ._load_stored_trip_model (
160
- user_id = self .user_id ,
161
- model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
162
- model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
163
- model_config = self .forest_model_config
164
- )
165
-
166
- # logging.debug(f'Load Test data')
167
- test = esda .get_entries (key = "analysis/confirmed_trip" , user_id = self .user_id , time_query = None )
168
-
169
- # logging.debug(f' Model Predictions on trips in database')
133
+ test_trip_data = esda .get_entries (key = esda .CONFIRMED_TRIP_KEY , user_id = self .user_id , time_query = None )
170
134
171
135
predictions_list_model1 = eamur .predict_labels_with_n (
172
- trip_list = test ,
173
- model = model_iter1
136
+ trip_list = test_trip_data ,
137
+ model = self . model
174
138
)
175
- # logging.debug(f' Loading Model again')
176
139
177
140
model_iter2 = eamur ._load_stored_trip_model (
178
141
user_id = self .user_id ,
179
142
model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
180
143
model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
181
144
model_config = self .forest_model_config
182
- )
183
- # logging.debug(f' Model Predictions on trips in database')
145
+ )
146
+
184
147
predictions_list_model2 = eamur .predict_labels_with_n (
185
- trip_list = test ,
148
+ trip_list = test_trip_data ,
186
149
model = model_iter2
187
150
)
188
-
189
- self .assertEqual (predictions_list_model1 , predictions_list_model2 , " should be equal" )
190
-
191
151
152
+ self .assertEqual (predictions_list_model1 , predictions_list_model2 , " should be equal" )
192
153
193
- def testSerializationErrorHandling (self ):
154
+ def testSerializationDeserializationErrorHandling (self ):
194
155
"""
195
- SerialisationErrorHandling : To verify that any errors during
156
+ SerializationErrorHandling : To verify that any errors during
196
157
serialising an object with 'to_dict' are handled.
158
+
159
+ DeserializationErrorHandling : To verify that any errors during
160
+ deserialising an object with 'from_dict' are handled.
197
161
"""
198
- # defining a side effect function to simulate a serialization error
162
+ # Test 1: SerializationErrorHandling
163
+ # Defining a side effect function to simulate a serialization error
199
164
def mock_dump (* args ,** kwargs ):
200
165
raise Exception ("Serialization Error" )
201
166
202
- logging .debug (f'(TRAIN) creating a model based on trips in database' )
203
- eamur .update_trip_model (
204
- user_id = self .user_id ,
205
- model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
206
- model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
207
- min_trips = self .min_trips ,
208
- model_config = self .forest_model_config
209
- )
210
-
211
- model = eamur ._load_stored_trip_model (
212
- user_id = self .user_id ,
213
- model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
214
- model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
215
- model_config = self .forest_model_config
216
- )
217
167
# patch is used to temporarily replace joblib.dump with a
218
168
# mock function that raises an exception
219
169
#
220
170
# side_effect, which is set to mock_dump, is called instead of
221
171
# real joblib.dump function when 'to_dict' is invoked
222
-
223
172
with um .patch ('joblib.dump' ,side_effect = mock_dump ):
224
173
with self .assertRaises (RuntimeError ):
225
- model .to_dict ()
226
-
174
+ self .model .to_dict ()
227
175
228
- def testDeserializationErrorHandling (self ):
229
- """
230
- deserialisationErrorHandling : To verify that any errors during
231
- deserialising an object with 'from_dict' are handled.
232
- """
233
- # defining a side effect function to simulate a deserialization error
176
+ # Test 2: DeserializationErrorHandling
177
+ # Defining a side effect function to simulate a deserialization error
234
178
def mock_load (* args ,** kwargs ):
235
179
raise Exception ("Deserialization Error" )
236
-
237
- logging .debug (f'(TRAIN) creating a model based on trips in database' )
238
- eamur .update_trip_model (
239
- user_id = self .user_id ,
240
- model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
241
- model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
242
- min_trips = self .min_trips ,
243
- model_config = self .forest_model_config
244
- )
245
-
246
- model = eamur ._load_stored_trip_model (
247
- user_id = self .user_id ,
248
- model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER ,
249
- model_storage = eamums .ModelStorage .DOCUMENT_DATABASE ,
250
- model_config = self .forest_model_config
251
- )
252
-
253
- model_data = model .to_dict ()
254
180
181
+ model_data = self .model .to_dict ()
255
182
deserialized_model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER
256
183
deserialized_model = deserialized_model_type .build (self .forest_model_config )
184
+
257
185
# patch is used to temporarily replace joblib.load with a
258
186
# mock function that raises an exception
259
187
#
260
188
# side_effect, which is set to mock_load, is called instead of
261
189
# real joblib.load function when 'to_dict' is invoked
262
-
263
190
with um .patch ('joblib.load' ,side_effect = mock_load ):
264
191
with self .assertRaises (RuntimeError ):
265
192
deserialized_model .from_dict (model_data )
266
193
267
-
268
- def testRandomForestTypePreservation (self ):
269
- """
270
- TypePreservationTest: To ensure that the serialization and deserialization
271
- process maintains the data types of all model attributes.
272
- The model is trained, preditions stored, serialised and then desserialized.
273
- The type of deserialised model attributes and the predictions of this must match
274
- those of initial model.
275
- """
276
- ct_entry = eamtr ._get_training_data (self .user_id ,None )
277
- split = int (len (ct_entry )* 0.8 )
278
- trips = ct_entry [:split ]
279
- test_trips = ct_entry [split :]
280
-
281
- ## Build and train model
282
- model_type = eamumt .ModelType .RANDOM_FOREST_CLASSIFIER
283
- model = model_type .build (self .forest_model_config )
284
- model .fit (trips )
285
-
286
- ## Get pre serialization predictions
287
- predictions_list = eamur .predict_labels_with_n (
288
- trip_list = test_trips ,
289
- model = model
290
- )
291
-
292
- ## Serialise
293
- serialised_model_data = model .to_dict ()
294
-
295
- ## build and deserialise a different model
296
- deserialised_model = model_type .build (self .forest_model_config )
297
- deserialised_model .from_dict (serialised_model_data )
298
-
299
- ## test if the types are correct
300
- for attr in ['purpose_predictor' ,'mode_predictor' ,'replaced_predictor' ,'purpose_enc' ,'mode_enc' ,'train_df' ]:
301
- deSerialised_attr_value = getattr (deserialised_model .model ,attr )
302
- original_attr_value = getattr (model .model ,attr )
303
- #Check type preservation
304
- self .assertIsInstance (deSerialised_attr_value ,type (original_attr_value ), f"Type mismatch for { attr } " )
305
- #Check for value equality. This assumes that the attributes are either direc
306
-
307
- ## test if the predictions are correct
308
- deserialised_predictions_list = eamur .predict_labels_with_n (
309
- trip_list = test_trips ,
310
- model = deserialised_model
311
- )
312
- logging .debug (f'TESTIN:{ deserialised_predictions_list } ' )
313
- logging .debug (f'{ predictions_list } ' )
314
- self .assertEqual (deserialised_predictions_list ,predictions_list ,'predictions list not same.' )
315
-
194
+ if __name__ == '__main__' :
195
+ etc .configLogging ()
196
+ unittest .main ()
0 commit comments