Skip to content

Update global learners to work with scikit-learn 1.6.0 #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 7, 2025
Merged
2 changes: 1 addition & 1 deletion doubleml/irm/tests/test_apos_external_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def doubleml_apos_ext_fixture(n_rep, treatment_levels, set_ml_m_ext, set_ml_g_ex
"draw_sample_splitting": False,
}

dml_obj = DoubleMLAPOS(ml_g=LinearRegression(), ml_m=LogisticRegression(), **kwargs)
dml_obj = DoubleMLAPOS(ml_g=LinearRegression(), ml_m=LogisticRegression(random_state=42), **kwargs)
dml_obj.set_sample_splitting(all_smpls=all_smpls)

np.random.seed(3141)
Expand Down
15 changes: 13 additions & 2 deletions doubleml/irm/tests/test_ssm_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ def predict_proba(self):
pass


class LogisticRegressionManipulatedType(LogisticRegression):
    """LogisticRegression variant whose sklearn tags report no estimator type.

    Used in the exception tests to build a learner that is a classifier in
    practice but whose tags carry ``estimator_type = None``.
    """

    def __sklearn_tags__(self):
        # Start from the parent's tags and blank out only the estimator type.
        manipulated_tags = super().__sklearn_tags__()
        manipulated_tags.estimator_type = None
        return manipulated_tags


@pytest.mark.ci
@pytest.mark.filterwarnings(
r"ignore:.*is \(probably\) neither a regressor nor a classifier.*:UserWarning",
Expand Down Expand Up @@ -233,9 +240,13 @@ def test_ssm_exception_learner():

# construct a classifier which is not identifiable as classifier via is_classifier by sklearn
# it then predicts labels and therefore an exception will be thrown
log_reg = LogisticRegression()
log_reg = LogisticRegressionManipulatedType()
# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
log_reg._estimator_type = None
msg = r"Learner provided for ml_m is probably invalid: LogisticRegression\(\) is \(probably\) no classifier."
msg = (
r"Learner provided for ml_m is probably invalid: LogisticRegressionManipulatedType\(\) is \(probably\) "
"no classifier."
)
with pytest.warns(UserWarning, match=msg):
_ = DoubleMLSSM(dml_data_mar, ml_g, ml_pi, log_reg)

Expand Down
15 changes: 11 additions & 4 deletions doubleml/tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,11 @@ def predict_proba(self):


class LogisticRegressionManipulatedPredict(LogisticRegression):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = None
return tags

def predict(self, X):
if self.max_iter == 314:
preds = super().predict_proba(X)[:, 1]
Expand Down Expand Up @@ -1063,16 +1068,17 @@ def test_doubleml_exception_learner():

# construct a classifier which is not identifiable as classifier via is_classifier by sklearn
# it then predicts labels and therefore an exception will be thrown
log_reg = LogisticRegression()
log_reg = LogisticRegressionManipulatedPredict()
# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
log_reg._estimator_type = None
msg = (
r"Learner provided for ml_m is probably invalid: LogisticRegression\(\) is \(probably\) neither a regressor "
"nor a classifier. Method predict is used for prediction."
r"Learner provided for ml_m is probably invalid: LogisticRegressionManipulatedPredict\(\) is \(probably\) "
"neither a regressor nor a classifier. Method predict is used for prediction."
)
with pytest.warns(UserWarning, match=msg):
dml_plr_hidden_classifier = DoubleMLPLR(dml_data_irm, Lasso(), log_reg)
msg = (
r"For the binary variable d, predictions obtained with the ml_m learner LogisticRegression\(\) "
r"For the binary variable d, predictions obtained with the ml_m learner LogisticRegressionManipulatedPredict\(\) "
"are also observed to be binary with values 0 and 1. Make sure that for classifiers probabilities and not "
"labels are predicted."
)
Expand All @@ -1083,6 +1089,7 @@ def test_doubleml_exception_learner():
# it then predicts labels and therefore an exception will be thrown
# whether predict() or predict_proba() is being called can also be manipulated via the unrelated max_iter variable
log_reg = LogisticRegressionManipulatedPredict()
# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
log_reg._estimator_type = None
msg = (
r"Learner provided for ml_g is probably invalid: LogisticRegressionManipulatedPredict\(\) is \(probably\) "
Expand Down
6 changes: 3 additions & 3 deletions doubleml/utils/dummy_learners.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin


class DMLDummyRegressor(BaseEstimator):
class DMLDummyRegressor(RegressorMixin, BaseEstimator):
"""
A dummy regressor that raises an AttributeError when attempting to access
its fit, predict, or set_params methods.
Expand Down Expand Up @@ -35,7 +35,7 @@ def set_params(*args):
raise AttributeError("Accessed set_params method of DMLDummyRegressor!")


class DMLDummyClassifier(BaseEstimator):
class DMLDummyClassifier(ClassifierMixin, BaseEstimator):
"""
A dummy classifier that raises an AttributeError when attempting to access
its fit, predict, set_params, or predict_proba methods.
Expand Down
44 changes: 36 additions & 8 deletions doubleml/utils/global_learner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from sklearn import __version__ as sklearn_version
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone, is_classifier, is_regressor
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import _check_sample_weight, check_is_fitted


class GlobalRegressor(BaseEstimator, RegressorMixin):
def parse_version(version):
    """Return the leading (major, minor) components of a version string as ints.

    Only the first two dot-separated components are considered. Each component
    is reduced to its leading run of digits, so pre-release or local suffixes
    fused to a number (e.g. ``"1.6rc1"``) do not break parsing.

    Parameters
    ----------
    version : str
        Version string such as ``"1.6.0"``, ``"1.7.dev0"`` or ``"1.6rc1"``.

    Returns
    -------
    tuple of int
        ``(major, minor)``; shorter if the string has fewer than two components.

    Raises
    ------
    ValueError
        If one of the first two components does not start with a digit.
    """
    components = []
    for part in version.split(".")[:2]:
        part = part.strip()
        # Keep only the leading digits; whatever follows (rc1, a0, +local, ...) is ignored.
        digits = part[: len(part) - len(part.lstrip("0123456789"))]
        if not digits:
            raise ValueError(f"Cannot parse version component {part!r} of version {version!r}.")
        components.append(int(digits))
    return tuple(components)


# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
sklearn_supports_validation = parse_version(sklearn_version) >= (1, 6)
if sklearn_supports_validation:
from sklearn.utils.validation import validate_data


class GlobalRegressor(RegressorMixin, BaseEstimator):
"""
A global regressor that ignores the attribute `sample_weight` when being fit to ensure a global fit.

Expand All @@ -13,9 +25,6 @@ class GlobalRegressor(BaseEstimator, RegressorMixin):
"""

def __init__(self, base_estimator):
if not is_regressor(base_estimator):
raise ValueError(f"base_estimator must be a regressor. Got {base_estimator.__class__.__name__} instead.")

self.base_estimator = base_estimator

def fit(self, X, y, sample_weight=None):
Expand All @@ -33,6 +42,15 @@ def fit(self, X, y, sample_weight=None):
sample_weight: array-like of shape (n_samples,).
Individual weights for each sample. Ignored.
"""
if not is_regressor(self.base_estimator):
raise ValueError(f"base_estimator must be a regressor. Got {self.base_estimator.__class__.__name__} instead.")

# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
if sklearn_supports_validation:
X, y = validate_data(self, X, y)
else:
X, y = self._validate_data(X, y)
_check_sample_weight(sample_weight, X)
self._fitted_learner = clone(self.base_estimator)
self._fitted_learner.fit(X, y)

Expand All @@ -47,10 +65,12 @@ def predict(self, X):
X: array-like of shape (n_samples, n_features)
Samples.
"""

check_is_fitted(self)
return self._fitted_learner.predict(X)


class GlobalClassifier(BaseEstimator, ClassifierMixin):
class GlobalClassifier(ClassifierMixin, BaseEstimator):
"""
A global classifier that ignores the attribute ``sample_weight`` when being fit to ensure a global fit.

Expand All @@ -61,9 +81,6 @@ class GlobalClassifier(BaseEstimator, ClassifierMixin):
"""

def __init__(self, base_estimator):
if not is_classifier(base_estimator):
raise ValueError(f"base_estimator must be a classifier. Got {base_estimator.__class__.__name__} instead.")

self.base_estimator = base_estimator

def fit(self, X, y, sample_weight=None):
Expand All @@ -81,6 +98,15 @@ def fit(self, X, y, sample_weight=None):
sample_weight: array-like of shape (n_samples,).
Individual weights for each sample. Ignored.
"""
if not is_classifier(self.base_estimator):
raise ValueError(f"base_estimator must be a classifier. Got {self.base_estimator.__class__.__name__} instead.")

# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
if sklearn_supports_validation:
X, y = validate_data(self, X, y)
else:
X, y = self._validate_data(X, y)
_check_sample_weight(sample_weight, X)
self.classes_ = unique_labels(y)
self._fitted_learner = clone(self.base_estimator)
self._fitted_learner.fit(X, y)
Expand All @@ -96,6 +122,7 @@ def predict(self, X):
X: array-like of shape (n_samples, n_features)
Samples.
"""
check_is_fitted(self)
return self._fitted_learner.predict(X)

def predict_proba(self, X):
Expand All @@ -108,4 +135,5 @@ def predict_proba(self, X):
X: array-like of shape (n_samples, n_features)
Samples to be scored.
"""
check_is_fitted(self)
return self._fitted_learner.predict_proba(X)
6 changes: 4 additions & 2 deletions doubleml/utils/tests/test_exceptions_global_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
def test_global_regressor_input():
msg = "base_estimator must be a regressor. Got LogisticRegression instead."
with pytest.raises(ValueError, match=msg):
_ = GlobalRegressor(base_estimator=LogisticRegression(random_state=42))
reg = GlobalRegressor(base_estimator=LogisticRegression(random_state=42))
reg.fit(X=[[1, 2], [3, 4]], y=[1, 2])


@pytest.mark.ci
def test_global_classifier_input():
msg = "base_estimator must be a classifier. Got LinearRegression instead."
with pytest.raises(ValueError, match=msg):
_ = GlobalClassifier(base_estimator=LinearRegression())
clas = GlobalClassifier(base_estimator=LinearRegression())
clas.fit(X=[[1, 2], [3, 4]], y=[1, 2])
133 changes: 129 additions & 4 deletions doubleml/utils/tests/test_global_learners.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import numpy as np
import pytest
from sklearn import __version__ as sklearn_version
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, StackingClassifier, StackingRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import KFold
from sklearn.utils.estimator_checks import check_estimator

from doubleml.utils import GlobalClassifier, GlobalRegressor


def parse_version(version):
    """Return the leading (major, minor) components of a version string as ints.

    Only the first two dot-separated components are considered. Each component
    is reduced to its leading run of digits, so pre-release or local suffixes
    fused to a number (e.g. ``"1.6rc1"``) do not break parsing.

    Parameters
    ----------
    version : str
        Version string such as ``"1.6.0"``, ``"1.7.dev0"`` or ``"1.6rc1"``.

    Returns
    -------
    tuple of int
        ``(major, minor)``; shorter if the string has fewer than two components.

    Raises
    ------
    ValueError
        If one of the first two components does not start with a digit.
    """
    components = []
    for part in version.split(".")[:2]:
        part = part.strip()
        # Keep only the leading digits; whatever follows (rc1, a0, +local, ...) is ignored.
        digits = part[: len(part) - len(part.lstrip("0123456789"))]
        if not digits:
            raise ValueError(f"Cannot parse version component {part!r} of version {version!r}.")
        components.append(int(digits))
    return tuple(components)


# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
sklearn_post_1_6 = parse_version(sklearn_version) >= (1, 6)


@pytest.fixture(
scope="module", params=[LinearRegression(), RandomForestRegressor(n_estimators=10, max_depth=2, random_state=42)]
)
Expand All @@ -22,6 +33,36 @@ def classifier(request):
return request.param


@pytest.mark.ci
def test_global_regressor(regressor):
    """Run scikit-learn's ``check_estimator`` on a ``GlobalRegressor`` wrapping ``regressor``.

    The test is skipped when ``sklearn_post_1_6`` is false.
    """
    # TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
    if not sklearn_post_1_6:
        pytest.skip("sklearn version is too old for this test")

    # Checks that are expected to fail, each with the reason passed to sklearn.
    known_failures = {
        "check_sample_weight_equivalence_on_dense_data": "weights are ignored",
        "check_estimators_nan_inf": "allowed for some estimators",
    }
    check_estimator(
        estimator=GlobalRegressor(base_estimator=regressor),
        expected_failed_checks=known_failures,
    )


@pytest.mark.ci
def test_global_classifier(classifier):
    """Run scikit-learn's ``check_estimator`` on a ``GlobalClassifier`` wrapping ``classifier``.

    The test is skipped when ``sklearn_post_1_6`` is false.
    """
    # TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
    if not sklearn_post_1_6:
        pytest.skip("sklearn version is too old for this test")

    # Checks that are expected to fail, each with the reason passed to sklearn.
    known_failures = {
        "check_sample_weight_equivalence_on_dense_data": "weights are ignored",
        "check_estimators_nan_inf": "allowed for some estimators",
    }
    check_estimator(
        estimator=GlobalClassifier(base_estimator=classifier),
        expected_failed_checks=known_failures,
    )


@pytest.fixture(scope="module")
def gl_fixture(regressor, classifier):
global_reg = GlobalRegressor(base_estimator=regressor)
Expand Down Expand Up @@ -54,9 +95,9 @@ def gl_fixture(regressor, classifier):
weighted_clas_pred = weighted_clas.predict(X=X)
unweighted_clas_pred = unweighted_clas.predict(X=X)

global_clas_pred_proba = global_clas.predict(X=X)
weighted_clas_pred_proba = weighted_clas.predict(X=X)
unweighted_clas_pred_proba = unweighted_clas.predict(X=X)
global_clas_pred_proba = global_clas.predict_proba(X=X)
weighted_clas_pred_proba = weighted_clas.predict_proba(X=X)
unweighted_clas_pred_proba = unweighted_clas.predict_proba(X=X)

result_dict = {
"GlobalRegressor": global_reg,
Expand Down Expand Up @@ -120,3 +161,87 @@ def test_clone(gl_fixture):

np.allclose(pred_global_reg, pred_clone_reg)
np.allclose(pred_global_clas, pred_clone_clas)


@pytest.fixture(scope="module")
def gl_stacking_fixture():
    """Fit stacking ensembles of global learners and of plain learners on the same data.

    The "global" stacks wrap every base and final estimator in
    ``GlobalRegressor`` / ``GlobalClassifier`` and are fit with ``sample_weight``;
    the "unweighted" stacks use the bare estimators and are fit without weights.

    Returns
    -------
    dict
        Predictions (and class probabilities for the classifiers) of both
        variants on the training features.
    """
    n_obs = 100
    forest_reg = RandomForestRegressor(n_estimators=10, max_depth=2, random_state=42)
    forest_clas = RandomForestClassifier(n_estimators=10, max_depth=2, random_state=42)

    # Draw order is kept as features, continuous target, binary target, weights.
    X = np.random.normal(0, 1, size=(n_obs, 2))
    y_con = np.random.normal(0, 1, size=(n_obs))
    y_cat = np.random.binomial(1, 0.5, size=(n_obs))
    sample_weight = np.random.random(size=(n_obs))

    # The same deterministic splitter is shared by all four stacks.
    kf = KFold(n_splits=2, shuffle=False)

    global_reg_members = [
        ("global", GlobalRegressor(base_estimator=clone(forest_reg))),
        ("lr", GlobalRegressor(LinearRegression())),
    ]
    global_reg = StackingRegressor(
        global_reg_members,
        final_estimator=GlobalRegressor(LinearRegression()),
        cv=kf,
    )

    plain_reg_members = [("global", clone(forest_reg)), ("lr", LinearRegression())]
    unweighted_reg = StackingRegressor(
        plain_reg_members,
        final_estimator=LinearRegression(),
        cv=kf,
    )

    global_clas_members = [
        ("global", GlobalClassifier(base_estimator=clone(forest_clas))),
        ("lr", GlobalClassifier(LogisticRegression(random_state=42))),
    ]
    global_clas = StackingClassifier(
        global_clas_members,
        final_estimator=GlobalClassifier(LogisticRegression(random_state=42)),
        cv=kf,
    )

    plain_clas_members = [
        ("global", clone(forest_clas)),
        ("lr", LogisticRegression(random_state=42)),
    ]
    unweighted_clas = StackingClassifier(
        plain_clas_members,
        final_estimator=LogisticRegression(random_state=42),
        cv=kf,
    )

    # Only the global stacks receive the sample weights.
    global_reg.fit(y=y_con, X=X, sample_weight=sample_weight)
    unweighted_reg.fit(y=y_con, X=X)
    global_clas.fit(y=y_cat, X=X, sample_weight=sample_weight)
    unweighted_clas.fit(y=y_cat, X=X)

    return {
        "global_reg_pred": global_reg.predict(X=X),
        "unweighted_reg_pred": unweighted_reg.predict(X=X),
        "global_clas_pred": global_clas.predict(X=X),
        "unweighted_clas_pred": unweighted_clas.predict(X=X),
        "global_clas_pred_proba": global_clas.predict_proba(X=X),
        "unweighted_clas_pred_proba": unweighted_clas.predict_proba(X=X),
    }


@pytest.mark.ci
def test_stacking_predict(gl_stacking_fixture):
    """Check that global and unweighted stacks give matching ``predict`` output."""
    compared_keys = (
        ("global_reg_pred", "unweighted_reg_pred"),
        ("global_clas_pred", "unweighted_clas_pred"),
    )
    # Regressor pair first, then classifier pair.
    for global_key, unweighted_key in compared_keys:
        assert np.allclose(gl_stacking_fixture[global_key], gl_stacking_fixture[unweighted_key])


@pytest.mark.ci
def test_stacking_predict_proba(gl_stacking_fixture):
    """Check that global and unweighted stacked classifiers give matching probabilities."""
    global_proba = gl_stacking_fixture["global_clas_pred_proba"]
    unweighted_proba = gl_stacking_fixture["unweighted_clas_pred_proba"]
    assert np.allclose(global_proba, unweighted_proba)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"numpy",
"pandas",
"scipy",
"scikit-learn>=1.4.0,<1.6.0",
"scikit-learn>=1.4.0",
"statsmodels",
"matplotlib",
"plotly"
Expand Down