Skip to content

Commit 4531ff5

Browse files
authored
[python-package] adapt to scikit-learn 1.6 testing changes, pin more packages in R 3.6 CI jobs (#6718)
1 parent 5151fe8 commit 4531ff5

File tree

5 files changed

+135
-9
lines changed

5 files changed

+135
-9
lines changed

.ci/install-old-r-packages.R

Lines changed: 79 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,79 @@
1+
# [description]
2+
#
3+
# Installs a pinned set of packages that worked together
4+
# as of the last R 3.6 release.
5+
#
6+
7+
.install_packages <- function(packages) {
8+
install.packages( # nolint: undesirable_function
9+
pkgs = paste( # nolint: paste
10+
"https://cran.r-project.org/src/contrib/Archive"
11+
, packages
12+
, sep = "/"
13+
)
14+
, dependencies = FALSE
15+
, lib = Sys.getenv("R_LIBS")
16+
, repos = NULL
17+
)
18+
}
19+
20+
# when confronted with a bunch of URLs like this, install.packages() sometimes
21+
# struggles to determine install order... so install packages in batches here,
22+
# starting from the root of the dependency graph and working up
23+
24+
# there was only a single release of {praise}, so there is no contrib/Archive URL for it
25+
install.packages( # nolint: undesirable_function
26+
pkgs = "https://cran.r-project.org/src/contrib/praise_1.0.0.tar.gz"
27+
, dependencies = FALSE
28+
, lib = Sys.getenv("R_LIBS")
29+
, repos = NULL
30+
)
31+
32+
.install_packages(c(
33+
"brio/brio_1.1.4.tar.gz" # nolint: non_portable_path
34+
, "cli/cli_3.6.2.tar.gz" # nolint: non_portable_path
35+
, "crayon/crayon_1.5.2.tar.gz" # nolint: non_portable_path
36+
, "digest/digest_0.6.36.tar.gz" # nolint: non_portable_path
37+
, "evaluate/evaluate_0.23.tar.gz" # nolint: non_portable_path
38+
, "fansi/fansi_1.0.5.tar.gz" # nolint: non_portable_path
39+
, "fs/fs_1.6.4.tar.gz" # nolint: non_portable_path
40+
, "glue/glue_1.7.0.tar.gz" # nolint: non_portable_path
41+
, "jsonlite/jsonlite_1.8.8.tar.gz" # nolint: non_portable_path
42+
, "lattice/lattice_0.20-41.tar.gz" # nolint: non_portable_path
43+
, "magrittr/magrittr_2.0.2.tar.gz" # nolint: non_portable_path
44+
, "pkgconfig/pkgconfig_2.0.2.tar.gz" # nolint: non_portable_path
45+
, "ps/ps_1.8.0.tar.gz" # nolint: non_portable_path
46+
, "R6/R6_2.5.0.tar.gz" # nolint: non_portable_path
47+
, "rlang/rlang_1.1.3.tar.gz" # nolint: non_portable_path
48+
, "rprojroot/rprojroot_2.0.3.tar.gz" # nolint: non_portable_path
49+
, "utf8/utf8_1.2.3.tar.gz" # nolint: non_portable_path
50+
, "withr/withr_3.0.1.tar.gz" # nolint: non_portable_path
51+
))
52+
53+
.install_packages(c(
54+
"desc/desc_1.4.2.tar.gz" # nolint: non_portable_path
55+
, "diffobj/diffobj_0.3.4.tar.gz" # nolint: non_portable_path
56+
, "lifecycle/lifecycle_1.0.3.tar.gz" # nolint: non_portable_path
57+
, "processx/processx_3.8.3.tar.gz" # nolint: non_portable_path
58+
))
59+
60+
.install_packages(c(
61+
"callr/callr_3.7.5.tar.gz" # nolint: non_portable_path
62+
, "vctrs/vctrs_0.6.4.tar.gz" # nolint: non_portable_path
63+
))
64+
65+
.install_packages(c(
66+
"pillar/pillar_1.8.1.tar.gz" # nolint: non_portable_path
67+
, "tibble/tibble_3.2.0.tar.gz" # nolint: non_portable_path
68+
))
69+
70+
.install_packages(c(
71+
"pkgbuild/pkgbuild_1.4.4.tar.gz" # nolint: non_portable_path
72+
, "rematch2/rematch2_2.1.1.tar.gz" # nolint: non_portable_path
73+
, "waldo/waldo_0.5.3.tar.gz" # nolint: non_portable_path
74+
))
75+
76+
.install_packages(c(
77+
"pkgload/pkgload_1.3.4.tar.gz" # nolint: non_portable_path
78+
, "testthat/testthat_3.2.1.tar.gz" # nolint: non_portable_path
79+
))

.ci/test-r-package.sh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -108,10 +108,10 @@ if [[ $OS_NAME == "macos" ]]; then
108108
export R_TIDYCMD=/usr/local/bin/tidy
109109
fi
110110

111-
# fix for issue where CRAN was not returning {lattice} and {evaluate} when using R 3.6
111+
# fix for issue where CRAN was not returning {evaluate}, {lattice}, or {waldo} when using R 3.6
112112
# "Warning: dependency ‘lattice’ is not available"
113113
if [[ "${R_MAJOR_VERSION}" == "3" ]]; then
114-
Rscript --vanilla -e "install.packages(c('https://cran.r-project.org/src/contrib/Archive/lattice/lattice_0.20-41.tar.gz', 'https://cran.r-project.org/src/contrib/Archive/evaluate/evaluate_0.23.tar.gz'), repos = NULL, lib = '${R_LIB_PATH}')"
114+
Rscript --vanilla ./.ci/install-old-r-packages.R
115115
else
116116
# {Matrix} needs {lattice}, so this needs to run before manually installing {Matrix}.
117117
# This should be unnecessary on R >=4.4.0

python-package/lightgbm/compat.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,14 @@
1414
from sklearn.utils.multiclass import check_classification_targets
1515
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
1616

17+
# sklearn.utils Tags types can be imported unconditionally once
18+
# lightgbm's minimum scikit-learn version is 1.6 or higher
19+
try:
20+
from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags
21+
from sklearn.utils import RegressorTags as _sklearn_RegressorTags
22+
except ImportError:
23+
_sklearn_ClassifierTags = None
24+
_sklearn_RegressorTags = None
1725
try:
1826
from sklearn.exceptions import NotFittedError
1927
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
@@ -140,6 +148,8 @@ class _LGBMRegressorBase: # type: ignore
140148
_LGBMCheckClassificationTargets = None
141149
_LGBMComputeSampleWeight = None
142150
_LGBMValidateData = None
151+
_sklearn_ClassifierTags = None
152+
_sklearn_RegressorTags = None
143153
_sklearn_version = None
144154

145155
# additional scikit-learn imports only for type hints

python-package/lightgbm/sklearn.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,8 @@
4040
_LGBMModelBase,
4141
_LGBMRegressorBase,
4242
_LGBMValidateData,
43+
_sklearn_ClassifierTags,
44+
_sklearn_RegressorTags,
4345
_sklearn_version,
4446
dt_DataTable,
4547
pd_DataFrame,
@@ -703,7 +705,6 @@ def _update_sklearn_tags_from_dict(
703705
tags.input_tags.allow_nan = tags_dict["allow_nan"]
704706
tags.input_tags.sparse = "sparse" in tags_dict["X_types"]
705707
tags.target_tags.one_d_labels = "1dlabels" in tags_dict["X_types"]
706-
tags._xfail_checks = tags_dict["_xfail_checks"]
707708
return tags
708709

709710
def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]:
@@ -1291,7 +1292,10 @@ def _more_tags(self) -> Dict[str, Any]:
12911292
return tags
12921293

12931294
def __sklearn_tags__(self) -> "_sklearn_Tags":
1294-
return LGBMModel.__sklearn_tags__(self)
1295+
tags = LGBMModel.__sklearn_tags__(self)
1296+
tags.estimator_type = "regressor"
1297+
tags.regressor_tags = _sklearn_RegressorTags(multi_label=False)
1298+
return tags
12951299

12961300
def fit( # type: ignore[override]
12971301
self,
@@ -1350,7 +1354,10 @@ def _more_tags(self) -> Dict[str, Any]:
13501354
return tags
13511355

13521356
def __sklearn_tags__(self) -> "_sklearn_Tags":
1353-
return LGBMModel.__sklearn_tags__(self)
1357+
tags = LGBMModel.__sklearn_tags__(self)
1358+
tags.estimator_type = "classifier"
1359+
tags.classifier_tags = _sklearn_ClassifierTags(multi_class=True, multi_label=False)
1360+
return tags
13541361

13551362
def fit( # type: ignore[override]
13561363
self,

tests/python_package_test/test_sklearn.py

Lines changed: 34 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,18 @@
1717
from sklearn.metrics import accuracy_score, log_loss, mean_squared_error, r2_score
1818
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
1919
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier, MultiOutputRegressor, RegressorChain
20-
from sklearn.utils.estimator_checks import parametrize_with_checks
20+
from sklearn.utils.estimator_checks import parametrize_with_checks as sklearn_parametrize_with_checks
2121
from sklearn.utils.validation import check_is_fitted
2222

2323
import lightgbm as lgb
24-
from lightgbm.compat import DATATABLE_INSTALLED, PANDAS_INSTALLED, dt_DataTable, pd_DataFrame, pd_Series
24+
from lightgbm.compat import (
25+
DATATABLE_INSTALLED,
26+
PANDAS_INSTALLED,
27+
_sklearn_version,
28+
dt_DataTable,
29+
pd_DataFrame,
30+
pd_Series,
31+
)
2532

2633
from .utils import (
2734
assert_silent,
@@ -35,6 +42,9 @@
3542
softmax,
3643
)
3744

45+
SKLEARN_MAJOR, SKLEARN_MINOR, *_ = _sklearn_version.split(".")
46+
SKLEARN_VERSION_GTE_1_6 = (int(SKLEARN_MAJOR), int(SKLEARN_MINOR)) >= (1, 6)
47+
3848
decreasing_generator = itertools.count(0, -1)
3949
estimator_classes = (lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker)
4050
task_to_model_factory = {
@@ -1432,7 +1442,28 @@ def test_getting_feature_names_in_pd_input(estimator_class):
14321442
np.testing.assert_array_equal(model.feature_names_in_, X.columns)
14331443

14341444

1435-
@parametrize_with_checks([lgb.LGBMClassifier(), lgb.LGBMRegressor()])
1445+
# Starting with scikit-learn 1.6 (https://github.com/scikit-learn/scikit-learn/pull/30149),
1446+
# the only API for marking estimator tests as expected to fail is to pass a keyword argument
1447+
# to parametrize_with_checks(). That function didn't accept additional arguments in earlier
1448+
# versions.
1449+
#
1450+
# This block defines a patched version of parametrize_with_checks() so lightgbm's tests
1451+
# can be compatible with scikit-learn <1.6 and >=1.6.
1452+
#
1453+
# This should be removed once minimum supported scikit-learn version is at least 1.6.
1454+
if SKLEARN_VERSION_GTE_1_6:
1455+
parametrize_with_checks = sklearn_parametrize_with_checks
1456+
else:
1457+
1458+
def parametrize_with_checks(estimator, *args, **kwargs):
1459+
return sklearn_parametrize_with_checks(estimator)
1460+
1461+
1462+
def _get_expected_failed_tests(estimator):
1463+
return estimator._more_tags()["_xfail_checks"]
1464+
1465+
1466+
@parametrize_with_checks([lgb.LGBMClassifier(), lgb.LGBMRegressor()], expected_failed_checks=_get_expected_failed_tests)
14361467
def test_sklearn_integration(estimator, check):
14371468
estimator.set_params(min_child_samples=1, min_data_in_bin=1)
14381469
check(estimator)
@@ -1457,7 +1488,6 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato
14571488
assert sklearn_tags.input_tags.allow_nan is True
14581489
assert sklearn_tags.input_tags.sparse is True
14591490
assert sklearn_tags.target_tags.one_d_labels is True
1460-
assert sklearn_tags._xfail_checks == more_tags["_xfail_checks"]
14611491

14621492

14631493
@pytest.mark.parametrize("task", all_tasks)

0 commit comments

Comments
 (0)