Skip to content

Commit 1b523ed

Browse files
committed
[Tested] Improvements for model integration
1. Improved tests in `TestForestModelLoadandSave.py` 2. Better comments, imports, and cleanup
1 parent 104dd9a commit 1b523ed

File tree

8 files changed

+163
-411
lines changed

8 files changed

+163
-411
lines changed

emission/analysis/modelling/trip_model/dbscan_svm.py

Lines changed: 0 additions & 250 deletions
This file was deleted.

emission/analysis/modelling/trip_model/forest_classifier.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
import pandas as pd
2-
from sklearn.preprocessing import OneHotEncoder
31
import joblib
42
from typing import Dict, List, Optional, Tuple
5-
from sklearn.metrics.pairwise import haversine_distances
3+
import sklearn.metrics.pairwise as smp
64
import emission.core.wrapper.confirmedtrip as ecwc
75
import logging
86
from io import BytesIO
97

8+
import json
109
import emission.analysis.modelling.trip_model.trip_model as eamuu
1110
import emission.analysis.modelling.trip_model.config as eamtc
1211
import emission.storage.timeseries.builtin_timeseries as estb
1312
import emission.storage.decorations.trip_queries as esdtq
14-
from emission.analysis.modelling.trip_model.models import ForestClassifier
13+
import emission.analysis.modelling.trip_model.models as eamtm
1514

1615
EARTH_RADIUS = 6371000
1716

@@ -33,45 +32,33 @@ def __init__(self,config=None):
3332
'min_samples_leaf',
3433
'max_features',
3534
'bootstrap',
36-
]
37-
cluster_expected_keys= [
38-
'radius',
39-
'size_thresh',
40-
'purity_thresh',
41-
'gamma',
42-
'C',
43-
'use_start_clusters',
44-
'use_trip_clusters',
45-
]
46-
35+
]
36+
######### Not Tested #########
37+
# The code below is used when we cluster the coordinates (config['loc_feature'] == 'cluster')
38+
# before passing to Random Forest. Commenting this for now since it is not used. Not tested either.
39+
###############################
40+
41+
# cluster_expected_keys= [
42+
# 'radius',
43+
# 'size_thresh',
44+
# 'purity_thresh',
45+
# 'gamma',
46+
# 'C',
47+
# 'use_start_clusters',
48+
# 'use_trip_clusters',
49+
# ]
50+
#
51+
# if config['loc_feature'] == 'cluster':
52+
# for k in cluster_expected_keys:
53+
# if config.get(k) is None:
54+
# msg = f"cluster trip model config missing expected key {k}"
55+
# raise KeyError(msg)
56+
#######################################
4757
for k in random_forest_expected_keys:
4858
if config.get(k) is None:
4959
msg = f"forest trip model config missing expected key {k}"
5060
raise KeyError(msg)
51-
52-
if config['loc_feature'] == 'cluster':
53-
for k in cluster_expected_keys:
54-
if config.get(k) is None:
55-
msg = f"cluster trip model config missing expected key {k}"
56-
raise KeyError(msg)
57-
maxdepth =config['max_depth'] if config['max_depth']!='null' else None
58-
self.model=ForestClassifier( loc_feature=config['loc_feature'],
59-
radius= config['radius'],
60-
size_thresh=config['radius'],
61-
purity_thresh=config['purity_thresh'],
62-
gamma=config['gamma'],
63-
C=config['C'],
64-
n_estimators=config['n_estimators'],
65-
criterion=config['criterion'],
66-
max_depth=maxdepth,
67-
min_samples_split=config['min_samples_split'],
68-
min_samples_leaf=config['min_samples_leaf'],
69-
max_features=config['max_features'],
70-
bootstrap=config['bootstrap'],
71-
random_state=config['random_state'],
72-
# drop_unclustered=False,
73-
use_start_clusters=config['use_start_clusters'],
74-
use_trip_clusters=config['use_trip_clusters'])
61+
self.model=eamtm.ForestClassifier(**config)
7562

7663

7764
def fit(self,trips: List[ecwc.Confirmedtrip]):
@@ -139,9 +126,15 @@ def to_dict(self):
139126
"""
140127
data={}
141128
attr=[ 'purpose_predictor','mode_predictor','replaced_predictor','purpose_enc','mode_enc','train_df']
142-
if self.model.loc_feature == 'cluster':
143-
## confirm this includes all the extra encoders/models
144-
attr.extend([ 'cluster_enc','end_cluster_model','start_cluster_model','trip_grouper'])
129+
130+
######### Not Tested #########
131+
# The code below is used when we cluster the coordinates (self.model.loc_feature == 'cluster')
132+
# before passing to Random Forest. Commenting this for now since it is not used. Not tested either.
133+
###############################
134+
# if self.model.loc_feature == 'cluster':
135+
# ## confirm this includes all the extra encoders/models
136+
# attr.extend([ 'cluster_enc','end_cluster_model','start_cluster_model','trip_grouper'])
137+
145138
for attribute_name in attr:
146139
if not hasattr(self.model,attribute_name):
147140
raise ValueError(f"Attribute {attribute_name} not found in the model")
@@ -153,17 +146,22 @@ def to_dict(self):
153146
raise RuntimeError(f"Error serializing { attribute_name}: {str(e)}")
154147
buffer.seek(0)
155148
data[attribute_name]=buffer.getvalue()
156-
149+
157150
return data
158151

159152
def from_dict(self,model: Dict):
160153
"""
161154
Load the model from a dictionary.
162155
"""
163156
attr=[ 'purpose_predictor','mode_predictor','replaced_predictor','purpose_enc','mode_enc','train_df']
164-
if self.model.loc_feature == 'cluster':
165-
## TODO : confirm this includes all the extra encoders/models
166-
attr.extend([ 'cluster_enc','end_cluster_model','start_cluster_model','trip_grouper'])
157+
158+
######### Not Tested #########
159+
# The code below is used when we cluster the coordinates (self.model.loc_feature == 'cluster')
160+
# before passing to Random Forest. Commenting this for now since it is not used. Not tested either.
161+
###############################
162+
# if self.model.loc_feature == 'cluster':
163+
# ## TODO : confirm this includes all the extra encoders/models
164+
# attr.extend([ 'cluster_enc','end_cluster_model','start_cluster_model','trip_grouper'])
167165
for attribute_name in attr:
168166
if attribute_name not in model:
169167
raise ValueError(f"Attribute {attribute_name} missing in the model")
@@ -183,6 +181,7 @@ def extract_features(self, trip: ecwc.Confirmedtrip) -> List[float]:
183181
:return: a vector containing features to predict from
184182
:rtype: List[float]
185183
"""
184+
# The ForestClassifier class in models.py handles feature extraction.
186185
pass
187186

188187
def is_incremental(self) -> bool:

0 commit comments

Comments
 (0)