Skip to content

Commit daf142e

Browse files
committed
refactor: Rename and move OneOrSeq
Planned in vega#3427 (comment) Will allow for more reuse
1 parent fd78fd6 commit daf142e

File tree

4 files changed

+54
-31
lines changed

4 files changed

+54
-31
lines changed

altair/vegalite/v5/api.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Union,
1515
TYPE_CHECKING,
1616
TypeVar,
17-
Sequence,
1817
Protocol,
1918
)
2019
from typing_extensions import TypeAlias
@@ -45,10 +44,6 @@
4544
from typing import TypedDict
4645
else:
4746
from typing_extensions import TypedDict
48-
if sys.version_info >= (3, 12):
49-
from typing import TypeAliasType
50-
else:
51-
from typing_extensions import TypeAliasType
5247

5348
if TYPE_CHECKING:
5449
from ...utils.core import DataFrameLike
@@ -125,26 +120,12 @@
125120
AggregateOp_T,
126121
MultiTimeUnit_T,
127122
SingleTimeUnit_T,
123+
OneOrSeq,
128124
)
129125

130126

131127
ChartDataType: TypeAlias = Optional[Union[DataType, core.Data, str, core.Generator]]
132128
_TSchemaBase = TypeVar("_TSchemaBase", bound=core.SchemaBase)
133-
_T = TypeVar("_T")
134-
_OneOrSeq = TypeAliasType("_OneOrSeq", Union[_T, Sequence[_T]], type_params=(_T,))
135-
"""One of ``_T`` specified type(s), or a `Sequence` of such.
136-
137-
Examples
138-
--------
139-
The parameters ``short``, ``long`` accept the same range of types::
140-
141-
# ruff: noqa: UP006, UP007
142-
143-
def func(
144-
short: _OneOrSeq[str | bool | float],
145-
long: Union[str, bool, float, Sequence[Union[str, bool, float]],
146-
): ...
147-
"""
148129

149130

150131
# ------------------------------------------------------------------------
@@ -571,7 +552,7 @@ class _ConditionExtra(TypedDict, closed=True, total=False): # type: ignore[call
571552
param: Parameter | str
572553
test: _TestPredicateType
573554
value: Any
574-
__extra_items__: _StatementType | _OneOrSeq[_LiteralValue]
555+
__extra_items__: _StatementType | OneOrSeq[_LiteralValue]
575556

576557

577558
_Condition: TypeAlias = _ConditionExtra

altair/vegalite/v5/schema/_typing.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from __future__ import annotations
66

7-
from typing import Any, Literal, Mapping
7+
from typing import Any, Literal, Mapping, Sequence, TypeVar, Union
88

9-
from typing_extensions import TypeAlias
9+
from typing_extensions import TypeAlias, TypeAliasType
1010

1111
__all__ = [
1212
"AggregateOp_T",
@@ -32,6 +32,7 @@
3232
"Mark_T",
3333
"MultiTimeUnit_T",
3434
"NonArgAggregateOp_T",
35+
"OneOrSeq",
3536
"Orient_T",
3637
"Orientation_T",
3738
"ProjectionType_T",
@@ -60,6 +61,22 @@
6061
]
6162

6263

64+
T = TypeVar("T")
65+
OneOrSeq = TypeAliasType("OneOrSeq", Union[T, Sequence[T]], type_params=(T,))
66+
"""One of ``T`` specified type(s), or a `Sequence` of such.
67+
68+
Examples
69+
--------
70+
The parameters ``short``, ``long`` accept the same range of types::
71+
72+
# ruff: noqa: UP006, UP007
73+
74+
def func(
75+
short: OneOrSeq[str | bool | float],
76+
long: Union[str, bool, float, Sequence[Union[str, bool, float]],
77+
): ...
78+
"""
79+
6380
Map: TypeAlias = Mapping[str, Any]
6481
AggregateOp_T: TypeAlias = Literal[
6582
"argmax",

tools/generate_schema_wrapper.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,26 @@ def encode({encode_method_args}) -> Self:
232232
return copy
233233
'''
234234

235+
# NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar`
236+
# Revisit if this starts to become more common
237+
TYPING_EXTRA: Final = '''
238+
T = TypeVar("T")
239+
OneOrSeq = TypeAliasType("OneOrSeq", Union[T, Sequence[T]], type_params=(T,))
240+
"""One of ``T`` specified type(s), or a `Sequence` of such.
241+
242+
Examples
243+
--------
244+
The parameters ``short``, ``long`` accept the same range of types::
245+
246+
# ruff: noqa: UP006, UP007
247+
248+
def func(
249+
short: OneOrSeq[str | bool | float],
250+
long: Union[str, bool, float, Sequence[Union[str, bool, float]],
251+
): ...
252+
"""
253+
'''
254+
235255

236256
class SchemaGenerator(codegen.SchemaGenerator):
237257
schema_class_template = textwrap.dedent(
@@ -815,7 +835,9 @@ def vegalite_main(skip_download: bool = False) -> None:
815835
)
816836
print(msg)
817837
TypeAliasTracer.update_aliases(("Map", "Mapping[str, Any]"))
818-
TypeAliasTracer.write_module(fp_typing, header=HEADER)
838+
TypeAliasTracer.write_module(
839+
fp_typing, "OneOrSeq", header=HEADER, extra=TYPING_EXTRA
840+
)
819841
# Write the pre-generated modules
820842
for fp, contents in files.items():
821843
print(f"Writing\n {schemafile!s}\n ->{fp!s}")

tools/schemapi/utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def __init__(
7171
self._aliases: dict[str, str] = {}
7272
self._imports: Sequence[str] = (
7373
"from __future__ import annotations\n",
74-
"from typing import Literal, Mapping, Any",
75-
"from typing_extensions import TypeAlias",
74+
"from typing import Any, Literal, Mapping, TypeVar, Sequence, Union",
75+
"from typing_extensions import TypeAlias, TypeAliasType",
7676
)
7777
self._cmd_check: list[str] = ["--fix"]
7878
self._cmd_format: Sequence[str] = ruff_format or ()
@@ -141,28 +141,31 @@ def is_cached(self, tp: str, /) -> bool:
141141
return tp in self._literals_invert or tp in self._literals
142142

143143
def write_module(
144-
self, fp: Path, *extra_imports: str, header: LiteralString
144+
self, fp: Path, *extra_all: str, header: LiteralString, extra: LiteralString
145145
) -> None:
146146
"""Write all collected `TypeAlias`'s to `fp`.
147147
148148
Parameters
149149
----------
150150
fp
151151
Path to new module.
152-
*extra_imports
153-
Follows `self._imports` block.
152+
*extra_all
153+
Any manually spelled types to be exported.
154154
header
155155
`tools.generate_schema_wrapper.HEADER`.
156+
extra
157+
`tools.generate_schema_wrapper.TYPING_EXTRA`.
156158
"""
157159
ruff_format = ["ruff", "format", fp]
158160
if self._cmd_format:
159161
ruff_format.extend(self._cmd_format)
160162
commands = (["ruff", "check", fp, *self._cmd_check], ruff_format)
161-
static = (header, "\n", *self._imports, *extra_imports, "\n\n")
163+
static = (header, "\n", *self._imports, "\n\n")
162164
self.update_aliases(*sorted(self._literals.items(), key=itemgetter(0)))
165+
all_ = [*iter(self._aliases), *extra_all]
163166
it = chain(
164167
static,
165-
[f"__all__ = {list(self._aliases)}", "\n\n"],
168+
[f"__all__ = {all_}", "\n\n", extra],
166169
self.generate_aliases(),
167170
)
168171
fp.write_text("\n".join(it), encoding="utf-8")

0 commit comments

Comments
 (0)