Skip to content

Commit e7f5d21

Browse files
Mahadik, Mukul Chandrakant
authored and committed
Cleaned up TestForestModelLoadandSave.py
a. Removed the check for test data left over from previous test runs.
- This should not be possible if data is cleared correctly in tearDown().
- Improved database clearing in tearDown() just to be sure.
b. Moved the model build to setUp() since all tests need it.
- I did see Shankari's comment stating that model building is a heavyweight process (e-mission@104dd9a#r1486605432).
- But it is required by all tests anyway, and moving it to setUp() helps reduce duplicate code.
c. Merged the EqualityTest with the TypePreservation test.
- Shankari had left a comment to check for values rather than checking for types (e-mission#938 (comment)).
- Satyam had added changes to check the predictions list after serialization and deserialization respectively.
- However, this equality test was already being done in a previous test.
- Hence merged these two.
d. Merged the Serialization and Deserialization error-handling tests.
- These tests were identical and mock functions were being used to assert raised exceptions.
- Merged these as well.
Merging tests helps reduce the number of times we have to build the model, since for all tests the common steps involve building the model and fetching predictions.
1 parent 419babb commit e7f5d21

File tree

1 file changed

+84
-203
lines changed

1 file changed

+84
-203
lines changed

emission/tests/modellingTests/TestForestModelLoadandSave.py

Lines changed: 84 additions & 203 deletions
Original file line number | Diff line number | Diff line change
@@ -12,19 +12,16 @@
1212
import emission.storage.decorations.analysis_timeseries_queries as esda
1313
import emission.core.get_database as edb
1414
import emission.analysis.modelling.trip_model.run_model as eamtr
15+
import emission.tests.common as etc
1516

1617
class TestForestModelLoadandSave(unittest.TestCase):
1718
"""
1819
Tests to make sure the model load and save properly
1920
"""
20-
2121
def setUp(self):
2222
"""
2323
sets up the end-to-end run model test with Confirmedtrip data
2424
"""
25-
logging.basicConfig(format='%(asctime)s:%(levelname)s:%(message)s',
26-
level=logging.DEBUG)
27-
2825
# configuration for randomly-generated test data
2926
self.user_id = user_id = 'TestForestModelLoadAndSave-TestData'
3027
self.origin = (-105.1705977, 39.7402654,)
@@ -38,278 +35,162 @@ def setUp(self):
3835

3936
self.unused_user_id = 'asdjfkl;asdfjkl;asd08234ur13fi4jhf2103mkl'
4037

41-
# Ensuring that no previous test data was left in DB after teardown,
4238
ts = esta.TimeSeries.get_time_series(user_id)
43-
test_data = list(ts.find_entries(["analysis/confirmed_trip"]))
44-
if len(test_data) == 0:
45-
# generate test data for the database
46-
logging.debug(f"inserting mock Confirmedtrips into database")
47-
48-
# generate labels with a known sample weight that we can rely on in the test
49-
label_data = {
50-
"mode_confirm": ['ebike', 'bike'],
51-
"purpose_confirm": ['happy-hour', 'dog-park'],
52-
"replaced_mode": ['walk'],
53-
"mode_weights": [0.9, 0.1],
54-
"purpose_weights": [0.1, 0.9]
55-
}
56-
57-
test_data = etmm.generate_mock_trips(
58-
user_id=user_id,
59-
trips=self.total_trips,
60-
origin=self.origin,
61-
destination=self.destination,
62-
trip_part='od',
63-
label_data=label_data,
64-
within_threshold=self.clustered_trips,
65-
threshold=0.004, # ~400m
66-
has_label_p=self.has_label_percent
67-
)
68-
69-
for result_entry in test_data:
70-
result_entry['data']['start_local_dt']=result_entry['metadata']['write_local_dt']
71-
result_entry['data']['end_local_dt']=result_entry['metadata']['write_local_dt']
7239

73-
ts.bulk_insert(test_data)
40+
# generate labels with a known sample weight that we can rely on in the test
41+
label_data = {
42+
"mode_confirm": ['ebike', 'bike'],
43+
"purpose_confirm": ['happy-hour', 'dog-park'],
44+
"replaced_mode": ['walk'],
45+
"mode_weights": [0.9, 0.1],
46+
"purpose_weights": [0.1, 0.9]
47+
}
48+
49+
# generate test data for the database
50+
test_data = etmm.generate_mock_trips(
51+
user_id=user_id,
52+
trips=self.total_trips,
53+
origin=self.origin,
54+
destination=self.destination,
55+
trip_part='od',
56+
label_data=label_data,
57+
within_threshold=self.clustered_trips,
58+
threshold=0.004, # ~400m
59+
has_label_p=self.has_label_percent
60+
)
7461

75-
# confirm data write did not fail
76-
test_data = esda.get_entries(key="analysis/confirmed_trip", user_id=user_id, time_query=None)
77-
if len(test_data) != self.total_trips:
78-
logging.debug(f'test invariant failed after generating test data')
79-
self.fail()
80-
else:
81-
logging.debug(f'found {self.total_trips} trips in database')
62+
for result_entry in test_data:
63+
result_entry['data']['start_local_dt']=result_entry['metadata']['write_local_dt']
64+
result_entry['data']['end_local_dt']=result_entry['metadata']['write_local_dt']
8265

83-
self.forest_model_config= eamtc.get_config_value_or_raise('model_parameters.forest')
84-
85-
def tearDown(self):
86-
"""
87-
clean up database
88-
"""
89-
edb.get_analysis_timeseries_db().delete_many({'user_id': self.user_id})
90-
edb.get_model_db().delete_many({'user_id': self.user_id})
91-
edb.get_pipeline_state_db().delete_many({'user_id': self.user_id})
66+
ts.bulk_insert(test_data)
9267

93-
def testForestModelRoundTrip(self):
94-
"""
95-
RoundTripTest : Serialising an object with 'to_dict' and then immediately
96-
deserialize it with 'from_dict'. After deserialization, the object should have
97-
the same state as original
98-
"""
68+
self.forest_model_config= eamtc.get_config_value_or_raise('model_parameters.forest')
9969

100-
# logging.debug(f'creating Random Forest model based on trips in database')
70+
# Build and train model
71+
logging.debug(f'(TRAIN) creating a model based on trips in database')
10172
eamur.update_trip_model(
10273
user_id=self.user_id,
10374
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
10475
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
10576
min_trips=self.min_trips,
10677
model_config=self.forest_model_config
10778
)
108-
109-
model = eamur._load_stored_trip_model(
79+
80+
self.model = eamur._load_stored_trip_model(
11081
user_id=self.user_id,
11182
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
11283
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
11384
model_config=self.forest_model_config
114-
)
85+
)
86+
87+
def tearDown(self):
88+
etc.dropAllCollections(edb._get_current_db())
11589

116-
# logging.debug(f'Loading test data')
117-
test = esda.get_entries(key="analysis/confirmed_trip", user_id=self.user_id, time_query=None)
90+
def testForestModelPredictionsEquality(self):
91+
"""
92+
EqualityTest : Serialising an object with 'to_dict' and then immediately
93+
deserialize it with 'from_dict'. After deserialization, the object should have
94+
the same state as original
11895
119-
# logging.debug(f'Predictions on trips in database')
96+
TypePreservationTest: To ensure that the serialization and deserialization
97+
process maintains the data types of all model attributes.
98+
The type of deserialized model attributes and the predictions of this must match
99+
those of initial model.
100+
"""
101+
test_trip_data = esda.get_entries(key=esda.CONFIRMED_TRIP_KEY, user_id=self.user_id, time_query=None)
120102

121103
predictions_list = eamur.predict_labels_with_n(
122-
trip_list = test,
123-
model=model
104+
trip_list = test_trip_data,
105+
model=self.model
124106
)
125107

126-
# logging.debug(f'Serialising the model ')
127-
128-
model_data=model.to_dict()
129-
130-
# logging.debug(f'Deserialising the model')
131-
132-
108+
model_data=self.model.to_dict()
133109
deserialized_model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER
134110
deserialized_model = deserialized_model_type.build(self.forest_model_config)
135111
deserialized_model.from_dict(model_data)
136112

137-
# logging.debug(f'Predictions on trips using deserialised model')
138-
predictions_loaded_model_list = eamur.predict_labels_with_n(
139-
trip_list = test,
140-
model=deserialized_model
113+
predictions_deserialized_model_list = eamur.predict_labels_with_n(
114+
trip_list = test_trip_data,
115+
model=deserialized_model
141116
)
142-
# logging.debug(f'Assert that both predictions are the same')
143-
self.assertEqual(predictions_list, predictions_loaded_model_list, " should be equal")
117+
118+
# Test if the types are correct
119+
for attr in ['purpose_predictor','mode_predictor','replaced_predictor','purpose_enc','mode_enc','train_df']:
120+
deserialized_attr_value=getattr(deserialized_model.model,attr)
121+
original_attr_value=getattr(self.model.model,attr)
122+
# Check type preservation
123+
self.assertIsInstance(deserialized_attr_value,type(original_attr_value), f"Type mismatch for {attr} ")
124+
125+
# Test if the values are the same
126+
self.assertEqual(predictions_list, predictions_deserialized_model_list, " should be equal")
144127

145128
def testForestModelConsistency(self):
146129
"""
147130
ConsistencyTest : To Verify that the serialization and deserialization process
148131
is consistent across multiple executions
149132
"""
150-
# logging.debug(f'creating a model based on trips in database')
151-
eamur.update_trip_model(
152-
user_id=self.user_id,
153-
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
154-
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
155-
min_trips=self.min_trips,
156-
model_config=self.forest_model_config
157-
)
158-
159-
model_iter1 = eamur._load_stored_trip_model(
160-
user_id=self.user_id,
161-
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
162-
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
163-
model_config=self.forest_model_config
164-
)
165-
166-
# logging.debug(f'Load Test data')
167-
test = esda.get_entries(key="analysis/confirmed_trip", user_id=self.user_id, time_query=None)
168-
169-
# logging.debug(f' Model Predictions on trips in database')
133+
test_trip_data = esda.get_entries(key=esda.CONFIRMED_TRIP_KEY, user_id=self.user_id, time_query=None)
170134

171135
predictions_list_model1 = eamur.predict_labels_with_n(
172-
trip_list = test,
173-
model=model_iter1
136+
trip_list = test_trip_data,
137+
model=self.model
174138
)
175-
# logging.debug(f' Loading Model again')
176139

177140
model_iter2 = eamur._load_stored_trip_model(
178141
user_id=self.user_id,
179142
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
180143
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
181144
model_config=self.forest_model_config
182-
)
183-
# logging.debug(f' Model Predictions on trips in database')
145+
)
146+
184147
predictions_list_model2 = eamur.predict_labels_with_n(
185-
trip_list = test,
148+
trip_list = test_trip_data,
186149
model=model_iter2
187150
)
188-
189-
self.assertEqual(predictions_list_model1, predictions_list_model2, " should be equal")
190-
191151

152+
self.assertEqual(predictions_list_model1, predictions_list_model2, " should be equal")
192153

193-
def testSerializationErrorHandling(self):
154+
def testSerializationDeserializationErrorHandling(self):
194155
"""
195-
SerialisationErrorHandling : To verify that any errors during
156+
SerializationErrorHandling : To verify that any errors during
196157
serialising an object with 'to_dict' are handled.
158+
159+
DeserializationErrorHandling : To verify that any errors during
160+
deserialising an object with 'from_dict' are handled.
197161
"""
198-
# defining a side effect function to simulate a serialization error
162+
# Test 1: SerializationErrorHandling
163+
# Defining a side effect function to simulate a serialization error
199164
def mock_dump(*args,**kwargs):
200165
raise Exception("Serialization Error")
201166

202-
logging.debug(f'(TRAIN) creating a model based on trips in database')
203-
eamur.update_trip_model(
204-
user_id=self.user_id,
205-
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
206-
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
207-
min_trips=self.min_trips,
208-
model_config=self.forest_model_config
209-
)
210-
211-
model = eamur._load_stored_trip_model(
212-
user_id=self.user_id,
213-
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
214-
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
215-
model_config=self.forest_model_config
216-
)
217167
# patch is used to temporarily replace joblib.dump with a
218168
# mock function that raises an exception
219169
#
220170
# side_effect, which is set to mock_dump, is called instead of
221171
# real joblib.dump function when 'to_dict' is invoked
222-
223172
with um.patch('joblib.dump',side_effect=mock_dump):
224173
with self.assertRaises(RuntimeError):
225-
model.to_dict()
226-
174+
self.model.to_dict()
227175

228-
def testDeserializationErrorHandling(self):
229-
"""
230-
deserialisationErrorHandling : To verify that any errors during
231-
deserialising an object with 'from_dict' are handled.
232-
"""
233-
# defining a side effect function to simulate a deserialization error
176+
# Test 2: DeserializationErrorHandling
177+
# Defining a side effect function to simulate a deserialization error
234178
def mock_load(*args,**kwargs):
235179
raise Exception("Deserialization Error")
236-
237-
logging.debug(f'(TRAIN) creating a model based on trips in database')
238-
eamur.update_trip_model(
239-
user_id=self.user_id,
240-
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
241-
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
242-
min_trips=self.min_trips,
243-
model_config=self.forest_model_config
244-
)
245-
246-
model = eamur._load_stored_trip_model(
247-
user_id=self.user_id,
248-
model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER,
249-
model_storage=eamums.ModelStorage.DOCUMENT_DATABASE,
250-
model_config=self.forest_model_config
251-
)
252-
253-
model_data=model.to_dict()
254180

181+
model_data=self.model.to_dict()
255182
deserialized_model_type=eamumt.ModelType.RANDOM_FOREST_CLASSIFIER
256183
deserialized_model = deserialized_model_type.build(self.forest_model_config)
184+
257185
# patch is used to temporarily replace joblib.load with a
258186
# mock function that raises an exception
259187
#
260188
# side_effect, which is set to mock_load, is called instead of
261189
# real joblib.load function when 'from_dict' is invoked
262-
263190
with um.patch('joblib.load',side_effect=mock_load):
264191
with self.assertRaises(RuntimeError):
265192
deserialized_model.from_dict(model_data)
266193

267-
268-
def testRandomForestTypePreservation(self):
269-
"""
270-
TypePreservationTest: To ensure that the serialization and deserialization
271-
process maintains the data types of all model attributes.
272-
The model is trained, preditions stored, serialised and then desserialized.
273-
The type of deserialised model attributes and the predictions of this must match
274-
those of initial model.
275-
"""
276-
ct_entry=eamtr._get_training_data(self.user_id,None)
277-
split= int(len(ct_entry)*0.8)
278-
trips=ct_entry[:split]
279-
test_trips=ct_entry[split:]
280-
281-
## Build and train model
282-
model_type= eamumt.ModelType.RANDOM_FOREST_CLASSIFIER
283-
model = model_type.build(self.forest_model_config)
284-
model.fit(trips)
285-
286-
## Get pre serialization predictions
287-
predictions_list = eamur.predict_labels_with_n(
288-
trip_list = test_trips,
289-
model=model
290-
)
291-
292-
## Serialise
293-
serialised_model_data=model.to_dict()
294-
295-
## build and deserialise a different model
296-
deserialised_model = model_type.build(self.forest_model_config)
297-
deserialised_model.from_dict(serialised_model_data)
298-
299-
## test if the types are correct
300-
for attr in ['purpose_predictor','mode_predictor','replaced_predictor','purpose_enc','mode_enc','train_df']:
301-
deSerialised_attr_value=getattr(deserialised_model.model,attr)
302-
original_attr_value=getattr(model.model,attr)
303-
#Check type preservation
304-
self.assertIsInstance(deSerialised_attr_value,type(original_attr_value), f"Type mismatch for {attr} ")
305-
#Check for value equality. This assumes that the attributes are either direc
306-
307-
## test if the predictions are correct
308-
deserialised_predictions_list = eamur.predict_labels_with_n(
309-
trip_list = test_trips,
310-
model=deserialised_model
311-
)
312-
logging.debug(f'TESTIN:{deserialised_predictions_list}')
313-
logging.debug(f'{predictions_list}')
314-
self.assertEqual(deserialised_predictions_list,predictions_list,'predictions list not same.')
315-
194+
if __name__ == '__main__':
195+
etc.configLogging()
196+
unittest.main()

0 commit comments

Comments
 (0)