Skip to content

Use default in TypeVar so Series defaults to Series[Any], and Index to Index[Any] #1232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ mypy round.py
we get the following error message:

```text
round.py:6: error: Argument "decimals" to "round" of "DataFrame" has incompatible type "DataFrame"; expected "Union[int, Dict[Any, Any], Series[Any]]" [arg-type]
round.py:6: error: Argument "decimals" to "round" of "DataFrame" has incompatible type "DataFrame"; expected "Union[int, Dict[Any, Any], Series]" [arg-type]
Found 1 error in 1 file (checked 1 source file)
```

Expand Down
4 changes: 2 additions & 2 deletions docs/philosophy.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ lt = s < 3

In the pandas source, `lt` is a `Series` with a `dtype` of `bool`. In the pandas-stubs,
the type of `lt` is `Series[bool]`. This allows further type checking to occur in other
pandas methods. Note that in the above example, `s` is typed as `Series[Any]` because
its type cannot be statically inferred.
pandas methods. Note that in the above example, `s` is just typed as `Series` (which
defaults to `Series[Any]`) because its type cannot be statically inferred.

This also allows type checking for operations on series that contain date/time data. Consider
the following example that creates two series of datetimes with corresponding arithmetic.
Expand Down
6 changes: 3 additions & 3 deletions pandas-stubs/_libs/tslibs/timestamps.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ from typing import (
import numpy as np
from pandas import (
DatetimeIndex,
Index,
TimedeltaIndex,
)
from pandas.core.indexes.base import UnknownIndex
from pandas.core.series import (
Series,
TimedeltaSeries,
Expand Down Expand Up @@ -236,15 +236,15 @@ class Timestamp(datetime, SupportsIndex):
@overload
def __eq__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
@overload
def __eq__(self, other: npt.NDArray[np.datetime64] | UnknownIndex) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
def __eq__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
@overload
def __eq__(self, other: object) -> Literal[False]: ...
@overload
def __ne__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
@overload
def __ne__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
@overload
def __ne__(self, other: npt.NDArray[np.datetime64] | UnknownIndex) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
def __ne__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
@overload
def __ne__(self, other: object) -> Literal[True]: ...
def __hash__(self) -> int: ...
Expand Down
38 changes: 10 additions & 28 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ from typing import (
Protocol,
SupportsIndex,
TypedDict,
TypeVar,
Union,
overload,
)
Expand All @@ -36,6 +35,7 @@ from pandas.core.tools.datetimes import FulldatetimeDict
from typing_extensions import (
ParamSpec,
TypeAlias,
TypeVar,
)

from pandas._libs.interval import Interval
Expand Down Expand Up @@ -66,7 +66,7 @@ HashableT5 = TypeVar("HashableT5", bound=Hashable)
# array-like

ArrayLike: TypeAlias = ExtensionArray | np.ndarray
AnyArrayLike: TypeAlias = ArrayLike | Index[Any] | Series[Any]
AnyArrayLike: TypeAlias = ArrayLike | Index | Series

# list-like

Expand Down Expand Up @@ -803,7 +803,7 @@ DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
KeysArgType: TypeAlias = Any
ListLikeT = TypeVar("ListLikeT", bound=ListLike)
ListLikeExceptSeriesAndStr: TypeAlias = (
MutableSequence[Any] | np.ndarray | tuple[Any, ...] | Index[Any]
MutableSequence[Any] | np.ndarray | tuple[Any, ...] | Index
)
ListLikeU: TypeAlias = Sequence | np.ndarray | Series | Index
ListLikeHashable: TypeAlias = (
Expand All @@ -826,29 +826,8 @@ MaskType: TypeAlias = Series[bool] | np_ndarray_bool | list[bool]

# Scratch types for generics

S1 = TypeVar(
"S1",
bound=str
| bytes
| datetime.date
| datetime.time
| bool
| int
| float
| complex
| Dtype
| datetime.datetime # includes pd.Timestamp
| datetime.timedelta # includes pd.Timedelta
| Period
| Interval
| CategoricalDtype
| BaseOffset
| list[str],
)

S2 = TypeVar(
"S2",
bound=str
SeriesDType: TypeAlias = (
str
| bytes
| datetime.date
| datetime.time
Expand All @@ -863,8 +842,11 @@ S2 = TypeVar(
| Interval
| CategoricalDtype
| BaseOffset
| list[str],
| list[str]
)
S1 = TypeVar("S1", bound=SeriesDType, default=Any)
# Like S1, but without `default=Any`.
S2 = TypeVar("S2", bound=SeriesDType)

IndexingInt: TypeAlias = (
int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8
Expand Down Expand Up @@ -951,7 +933,7 @@ ReplaceValue: TypeAlias = (
| NAType
| Sequence[Scalar | Pattern]
| Mapping[HashableT, ScalarT]
| Series[Any]
| Series
| None
)

Expand Down
8 changes: 4 additions & 4 deletions pandas-stubs/core/dtypes/missing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ isneginf_scalar = ...
@overload
def isna(obj: DataFrame) -> DataFrame: ...
@overload
def isna(obj: Series[Any]) -> Series[bool]: ...
def isna(obj: Series) -> Series[bool]: ...
@overload
def isna(obj: Index[Any] | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
def isna(obj: Index | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
@overload
def isna(
obj: Scalar | NaTType | NAType | None,
Expand All @@ -39,9 +39,9 @@ isnull = isna
@overload
def notna(obj: DataFrame) -> DataFrame: ...
@overload
def notna(obj: Series[Any]) -> Series[bool]: ...
def notna(obj: Series) -> Series[bool]: ...
@overload
def notna(obj: Index[Any] | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
def notna(obj: Index | list[Any] | ArrayLike) -> npt.NDArray[np.bool_]: ...
@overload
def notna(obj: ScalarT | NaTType | NAType | None) -> TypeIs[ScalarT]: ...

Expand Down
50 changes: 27 additions & 23 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ from pandas.core.reshape.pivot import (
)
from pandas.core.series import (
Series,
UnknownSeries,
)
from pandas.core.window import (
Expanding,
Expand All @@ -79,7 +78,7 @@ from pandas._libs.tslibs import BaseOffset
from pandas._libs.tslibs.nattype import NaTType
from pandas._libs.tslibs.offsets import DateOffset
from pandas._typing import (
S1,
S2,
AggFuncTypeBase,
AggFuncTypeDictFrame,
AggFuncTypeDictSeries,
Expand Down Expand Up @@ -1319,11 +1318,11 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def stack(
self, level: Level | list[Level] = ..., dropna: _bool = ..., sort: _bool = ...
) -> Self | Series[Any]: ...
) -> Self | Series: ...
@overload
def stack(
self, level: Level | list[Level] = ..., future_stack: _bool = ...
) -> Self | Series[Any]: ...
) -> Self | Series: ...
def explode(
self, column: Sequence[Hashable], ignore_index: _bool = ...
) -> Self: ...
Expand Down Expand Up @@ -1383,7 +1382,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr | Series[Any]],
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
axis: AxisIndex = ...,
raw: _bool = ...,
result_type: None = ...,
Expand All @@ -1393,13 +1392,14 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., S1 | NAType],
# Use S2 (TypeVar without `default=Any`) instead of S1 due to https://github.com/python/mypy/issues/19182.
f: Callable[..., S2 | NAType],
axis: AxisIndex = ...,
raw: _bool = ...,
result_type: None = ...,
args: Any = ...,
**kwargs: Any,
) -> Series[S1]: ...
) -> Series[S2]: ...
# Since non-scalar type T is not supported in Series[T],
# we separate this overload from the above one
@overload
Expand All @@ -1411,24 +1411,25 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
result_type: None = ...,
args: Any = ...,
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...

# apply() overloads with keyword result_type, and axis does not matter
@overload
def apply(
self,
f: Callable[..., S1 | NAType],
# Use S2 (TypeVar without `default=Any`) instead of S1 due to https://github.com/python/mypy/issues/19182.
f: Callable[..., S2 | NAType],
axis: Axis = ...,
raw: _bool = ...,
args: Any = ...,
*,
result_type: Literal["expand", "reduce"],
**kwargs: Any,
) -> Series[S1]: ...
) -> Series[S2]: ...
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr | Series[Any] | Mapping[Any, Any]],
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Mapping[Any, Any]],
axis: Axis = ...,
raw: _bool = ...,
args: Any = ...,
Expand All @@ -1446,12 +1447,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
*,
result_type: Literal["reduce"],
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...
@overload
def apply(
self,
f: Callable[
..., ListLikeExceptSeriesAndStr | Series[Any] | Scalar | Mapping[Any, Any]
..., ListLikeExceptSeriesAndStr | Series | Scalar | Mapping[Any, Any]
],
axis: Axis = ...,
raw: _bool = ...,
Expand All @@ -1465,27 +1466,28 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., Series[Any]],
f: Callable[..., Series],
axis: AxisIndex = ...,
raw: _bool = ...,
args: Any = ...,
*,
result_type: Literal["reduce"],
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...

# apply() overloads with default result_type of None, and keyword axis=1 matters
@overload
def apply(
self,
f: Callable[..., S1 | NAType],
# Use S2 (TypeVar without `default=Any`) instead of S1 due to https://github.com/python/mypy/issues/19182.
f: Callable[..., S2 | NAType],
raw: _bool = ...,
result_type: None = ...,
args: Any = ...,
*,
axis: AxisColumn,
**kwargs: Any,
) -> Series[S1]: ...
) -> Series[S2]: ...
@overload
def apply(
self,
Expand All @@ -1496,11 +1498,11 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
*,
axis: AxisColumn,
**kwargs: Any,
) -> Series[Any]: ...
) -> Series: ...
@overload
def apply(
self,
f: Callable[..., Series[Any]],
f: Callable[..., Series],
raw: _bool = ...,
result_type: None = ...,
args: Any = ...,
Expand All @@ -1513,7 +1515,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def apply(
self,
f: Callable[..., Series[Any]],
f: Callable[..., Series],
raw: _bool = ...,
args: Any = ...,
*,
Expand All @@ -1538,7 +1540,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
) -> Self: ...
def merge(
self,
right: DataFrame | Series[Any],
right: DataFrame | Series,
how: MergeHow = ...,
on: IndexLabel | AnyArrayLike | None = ...,
left_on: IndexLabel | AnyArrayLike | None = ...,
Expand Down Expand Up @@ -1684,6 +1686,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@property
def iloc(self) -> _iLocIndexerFrame[Self]: ...
@property
# mypy complains if we use Index[Any] instead of UnknownIndex here, even though
# the latter is aliased to the former ¯\_(ツ)_/¯.
def index(self) -> UnknownIndex: ...
Comment on lines 1688 to 1691
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a complete mystery to me - any ideas?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if you just used Index ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

tests/test_scalars.py:1380: error: Expression is of type "Any", not "ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]"  [assert-type]
tests/test_scalars.py:1383: error: Expression is of type "Any", not "ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]"  [assert-type]

the only difference i applied was

--- a/pandas-stubs/core/frame.pyi
+++ b/pandas-stubs/core/frame.pyi
@@ -1688,7 +1688,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
     @property
     # mypy complains if we use Index[Any] instead of UnknownIndex here, even though
     # the latter is aliased to the former ¯\_(ツ)_/¯.
-    def index(self) -> UnknownIndex: ...
+    def index(self) -> Index: ...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe leave it as Index in frame.pyi there, but change the __eq__() and __ne__() methods in timestamps.pyi to use Index[Any] instead of UnknownIndex

@index.setter
def index(self, idx: Index) -> None: ...
Expand Down Expand Up @@ -2012,7 +2016,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
| Callable[[DataFrame], DataFrame]
| Callable[[Any], _bool]
),
other: Scalar | UnknownSeries | DataFrame | Callable | NAType | None = ...,
other: Scalar | Series | DataFrame | Callable | NAType | None = ...,
*,
inplace: Literal[True],
axis: Axis | None = ...,
Expand All @@ -2028,7 +2032,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
| Callable[[DataFrame], DataFrame]
| Callable[[Any], _bool]
),
other: Scalar | UnknownSeries | DataFrame | Callable | NAType | None = ...,
other: Scalar | Series | DataFrame | Callable | NAType | None = ...,
*,
inplace: Literal[False] = ...,
axis: Axis | None = ...,
Expand Down
Loading