Skip to content

Commit 7aed2a0

Browse files
authored
Adapt to deprecations in pandas 2.2.0 (#3620)
* Adapt to deprecations in pandas 2.2.0 * Backcompat for bug where the fix was deprecated? * More multi-version support
1 parent a3cb0f1 commit 7aed2a0

File tree

8 files changed

+42
-14
lines changed

8 files changed

+42
-14
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,6 @@ exclude = ["doc/_static/*.svg"]
6666
[tool.pytest.ini_options]
6767
filterwarnings = [
6868
"ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning",
69+
"ignore:\\s*Pyarrow will become a required dependency of pandas:DeprecationWarning",
70+
"ignore:datetime.datetime.utcfromtimestamp\\(\\) is deprecated:DeprecationWarning",
6971
]

seaborn/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -942,9 +942,9 @@ def iter_data(
942942

943943
for key in iter_keys:
944944

945-
# Pandas fails with singleton tuple inputs
946-
pd_key = key[0] if len(key) == 1 else key
947-
945+
pd_key = (
946+
key[0] if len(key) == 1 and _version_predates(pd, "2.2.0") else key
947+
)
948948
try:
949949
data_subset = grouped_data.get_group(pd_key)
950950
except KeyError:

seaborn/_compat.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Literal
33

44
import numpy as np
5+
import pandas as pd
56
import matplotlib as mpl
67
from matplotlib.figure import Figure
78
from seaborn.utils import _version_predates
@@ -114,3 +115,9 @@ def get_legend_handles(legend):
114115
return legend.legendHandles
115116
else:
116117
return legend.legend_handles
118+
119+
120+
def groupby_apply_include_groups(val):
121+
if _version_predates(pd, "2.2.0"):
122+
return {}
123+
return {"include_groups": val}

seaborn/_core/plot.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from seaborn._core.exceptions import PlotSpecError
4242
from seaborn._core.rules import categorical_order
4343
from seaborn._compat import get_layout_engine, set_layout_engine
44+
from seaborn.utils import _version_predates
4445
from seaborn.rcmod import axes_style, plotting_context
4546
from seaborn.palettes import color_palette
4647

@@ -1637,9 +1638,10 @@ def split_generator(keep_na=False) -> Generator:
16371638

16381639
for key in itertools.product(*grouping_keys):
16391640

1640-
# Pandas fails with singleton tuple inputs
1641-
pd_key = key[0] if len(key) == 1 else key
1642-
1641+
pd_key = (
1642+
key[0] if len(key) == 1 and _version_predates(pd, "2.2.0")
1643+
else key
1644+
)
16431645
try:
16441646
df_subset = grouped_df.get_group(pd_key)
16451647
except KeyError:

seaborn/categorical.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
_scatter_legend_artist,
2929
_version_predates,
3030
)
31+
from seaborn._compat import groupby_apply_include_groups
3132
from seaborn._statistics import (
3233
EstimateAggregator,
3334
LetterValues,
@@ -634,10 +635,10 @@ def get_props(element, artist=mpl.lines.Line2D):
634635
ax = self._get_axes(sub_vars)
635636

636637
grouped = sub_data.groupby(self.orient)[value_var]
638+
positions = sorted(sub_data[self.orient].unique().astype(float))
637639
value_data = [x.to_numpy() for _, x in grouped]
638640
stats = pd.DataFrame(mpl.cbook.boxplot_stats(value_data, whis=whis,
639641
bootstrap=bootstrap))
640-
positions = grouped.grouper.result_index.to_numpy(dtype=float)
641642

642643
orig_width = width * self._native_width
643644
data = pd.DataFrame({self.orient: positions, "width": orig_width})
@@ -1207,7 +1208,7 @@ def plot_points(
12071208
agg_data = sub_data if sub_data.empty else (
12081209
sub_data
12091210
.groupby(self.orient)
1210-
.apply(aggregator, agg_var)
1211+
.apply(aggregator, agg_var, **groupby_apply_include_groups(False))
12111212
.reindex(pd.Index(positions, name=self.orient))
12121213
.reset_index()
12131214
)
@@ -1278,7 +1279,7 @@ def plot_bars(
12781279
agg_data = sub_data if sub_data.empty else (
12791280
sub_data
12801281
.groupby(self.orient)
1281-
.apply(aggregator, agg_var)
1282+
.apply(aggregator, agg_var, **groupby_apply_include_groups(False))
12821283
.reset_index()
12831284
)
12841285

seaborn/relational.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_get_transform_functions,
1818
_scatter_legend_artist,
1919
)
20+
from ._compat import groupby_apply_include_groups
2021
from ._statistics import EstimateAggregator, WeightedAggregator
2122
from .axisgrid import FacetGrid, _facet_docs
2223
from ._docstrings import DocstringComponents, _core_docs
@@ -290,7 +291,11 @@ def plot(self, ax, kws):
290291
grouped = sub_data.groupby(orient, sort=self.sort)
291292
# Could pass as_index=False instead of reset_index,
292293
# but that fails on a corner case with older pandas.
293-
sub_data = grouped.apply(agg, other).reset_index()
294+
sub_data = (
295+
grouped
296+
.apply(agg, other, **groupby_apply_include_groups(False))
297+
.reset_index()
298+
)
294299
else:
295300
sub_data[f"{other}min"] = np.nan
296301
sub_data[f"{other}max"] = np.nan

tests/_stats/test_density.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from seaborn._core.groupby import GroupBy
88
from seaborn._stats.density import KDE, _no_scipy
9+
from seaborn._compat import groupby_apply_include_groups
910

1011

1112
class TestKDE:
@@ -93,7 +94,10 @@ def test_common_norm(self, df, common_norm):
9394

9495
areas = (
9596
res.groupby("alpha")
96-
.apply(lambda x: self.integrate(x["density"], x[ori]))
97+
.apply(
98+
lambda x: self.integrate(x["density"], x[ori]),
99+
**groupby_apply_include_groups(False),
100+
)
97101
)
98102

99103
if common_norm:
@@ -111,11 +115,18 @@ def test_common_norm_variables(self, df):
111115
def integrate_by_color_and_sum(x):
112116
return (
113117
x.groupby("color")
114-
.apply(lambda y: self.integrate(y["density"], y[ori]))
118+
.apply(
119+
lambda y: self.integrate(y["density"], y[ori]),
120+
**groupby_apply_include_groups(False)
121+
)
115122
.sum()
116123
)
117124

118-
areas = res.groupby("alpha").apply(integrate_by_color_and_sum)
125+
areas = (
126+
res
127+
.groupby("alpha")
128+
.apply(integrate_by_color_and_sum, **groupby_apply_include_groups(False))
129+
)
119130
assert_array_almost_equal(areas, [1, 1], decimal=3)
120131

121132
@pytest.mark.parametrize("param", ["norm", "grid"])

tests/test_categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2078,7 +2078,7 @@ def test_xy_native_scale_log_transform(self):
20782078

20792079
def test_datetime_native_scale_axis(self):
20802080

2081-
x = pd.date_range("2010-01-01", periods=20, freq="m")
2081+
x = pd.date_range("2010-01-01", periods=20, freq="MS")
20822082
y = np.arange(20)
20832083
ax = barplot(x=x, y=y, native_scale=True)
20842084
assert "Date" in ax.xaxis.get_major_locator().__class__.__name__

0 commit comments

Comments
 (0)