Skip to content

Commit fe7e77e

Browse files
committed
feat(RFC): Adds agg, field utility classes
`field` proposed in vega#3239 (comment) `agg` was developed during vega#3427 (comment) as a solution to part of vega#3476
1 parent 679a7ce commit fe7e77e

File tree

1 file changed

+344
-0
lines changed

1 file changed

+344
-0
lines changed

altair/vegalite/v5/_api_rfc.py

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
"""
2+
Request for comment on additions to `api.py`.
3+
4+
Ideally these would be introduced *after* cleaning up the top-level namespace.
5+
6+
Actual runtime dependencies:
7+
- altair.utils.core
8+
- altair.utils.schemapi
9+
10+
The rest are to define aliases only.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Union
16+
17+
from typing_extensions import TypeAlias
18+
19+
from altair.utils.core import TYPECODE_MAP as _TYPE_CODE
20+
from altair.utils.core import parse_shorthand as _parse
21+
from altair.utils.schemapi import Optional, SchemaBase, Undefined
22+
from altair.vegalite.v5.api import Parameter, SelectionPredicateComposition
23+
from altair.vegalite.v5.schema._typing import (
24+
BinnedTimeUnit_T,
25+
MultiTimeUnit_T,
26+
SingleTimeUnit_T,
27+
Type_T,
28+
)
29+
from altair.vegalite.v5.schema.core import (
30+
FieldEqualPredicate,
31+
FieldGTEPredicate,
32+
FieldGTPredicate,
33+
FieldLTEPredicate,
34+
FieldLTPredicate,
35+
FieldOneOfPredicate,
36+
FieldRangePredicate,
37+
FieldValidPredicate,
38+
)
39+
40+
if TYPE_CHECKING:
41+
from altair.utils.core import DataFrameLike
42+
from altair.vegalite.v5.schema._typing import AggregateOp_T
43+
from altair.vegalite.v5.schema.core import Predicate
44+
45+
__all__ = ["agg", "field"]
46+
47+
EncodeType: TypeAlias = Union[Type_T, Literal["O", "N", "Q", "T", "G"], None]
48+
AnyTimeUnit: TypeAlias = Union[MultiTimeUnit_T, BinnedTimeUnit_T, SingleTimeUnit_T]
49+
TimeUnitType: TypeAlias = Optional[Union[Dict[str, Any], SchemaBase, AnyTimeUnit]]
50+
RangeType: TypeAlias = Union[
51+
Dict[str, Any],
52+
Parameter,
53+
SchemaBase,
54+
Sequence[Union[Dict[str, Any], None, float, Parameter, SchemaBase]],
55+
]
56+
ValueType: TypeAlias = Union[str, bool, float, Dict[str, Any], Parameter, SchemaBase]
57+
58+
59+
_ENCODINGS = frozenset(
60+
(
61+
"ordinal",
62+
"O",
63+
"nominal",
64+
"N",
65+
"quantitative",
66+
"Q",
67+
"temporal",
68+
"T",
69+
"geojson",
70+
"G",
71+
None,
72+
)
73+
)
74+
75+
76+
def _parse_aggregate(
77+
aggregate: AggregateOp_T, name: str | None, encode_type: EncodeType, /
78+
) -> dict[str, Any]:
79+
if encode_type in _ENCODINGS:
80+
enc = f":{_TYPE_CODE.get(s, s)}" if (s := encode_type) else ""
81+
return _parse(f"{aggregate}({name or ''}){enc}")
82+
else:
83+
msg = (
84+
f"Expected a short/long-form encoding type, but got {encode_type!r}.\n\n"
85+
f"Try passing one of the following to `type`:\n"
86+
f"{', '.join(sorted(f'{e!r}' for e in _ENCODINGS))}."
87+
)
88+
raise TypeError(msg)
89+
90+
91+
def _wrap_composition(predicate: Predicate, /) -> SelectionPredicateComposition:
92+
return SelectionPredicateComposition(predicate.to_dict())
93+
94+
95+
class agg:
96+
"""Utility class providing autocomplete for shorthand.
97+
98+
Functional alternative to shorthand mini-language.
99+
"""
100+
101+
def __new__( # type: ignore[misc]
102+
cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None
103+
) -> dict[str, Any]:
104+
return _parse(shorthand=shorthand, data=data)
105+
106+
@classmethod
107+
def argmin(
108+
cls, col_name: str | None = None, /, type: EncodeType = None
109+
) -> dict[str, Any]:
110+
return _parse_aggregate("argmin", col_name, type)
111+
112+
@classmethod
113+
def argmax(
114+
cls, col_name: str | None = None, /, type: EncodeType = None
115+
) -> dict[str, Any]:
116+
return _parse_aggregate("argmax", col_name, type)
117+
118+
@classmethod
119+
def average(
120+
cls, col_name: str | None = None, /, type: EncodeType = None
121+
) -> dict[str, Any]:
122+
return _parse_aggregate("average", col_name, type)
123+
124+
@classmethod
125+
def count(
126+
cls, col_name: str | None = None, /, type: EncodeType = "Q"
127+
) -> dict[str, Any]:
128+
return _parse_aggregate("count", col_name, type)
129+
130+
@classmethod
131+
def distinct(
132+
cls, col_name: str | None = None, /, type: EncodeType = None
133+
) -> dict[str, Any]:
134+
return _parse_aggregate("distinct", col_name, type)
135+
136+
@classmethod
137+
def max(
138+
cls, col_name: str | None = None, /, type: EncodeType = None
139+
) -> dict[str, Any]:
140+
return _parse_aggregate("max", col_name, type)
141+
142+
@classmethod
143+
def mean(
144+
cls, col_name: str | None = None, /, type: EncodeType = None
145+
) -> dict[str, Any]:
146+
return _parse_aggregate("mean", col_name, type)
147+
148+
@classmethod
149+
def median(
150+
cls, col_name: str | None = None, /, type: EncodeType = None
151+
) -> dict[str, Any]:
152+
return _parse_aggregate("median", col_name, type)
153+
154+
@classmethod
155+
def min(
156+
cls, col_name: str | None = None, /, type: EncodeType = None
157+
) -> dict[str, Any]:
158+
return _parse_aggregate("min", col_name, type)
159+
160+
@classmethod
161+
def missing(
162+
cls, col_name: str | None = None, /, type: EncodeType = None
163+
) -> dict[str, Any]:
164+
return _parse_aggregate("missing", col_name, type)
165+
166+
@classmethod
167+
def product(
168+
cls, col_name: str | None = None, /, type: EncodeType = None
169+
) -> dict[str, Any]:
170+
return _parse_aggregate("product", col_name, type)
171+
172+
@classmethod
173+
def q1(
174+
cls, col_name: str | None = None, /, type: EncodeType = None
175+
) -> dict[str, Any]:
176+
return _parse_aggregate("q1", col_name, type)
177+
178+
@classmethod
179+
def q3(
180+
cls, col_name: str | None = None, /, type: EncodeType = None
181+
) -> dict[str, Any]:
182+
return _parse_aggregate("q3", col_name, type)
183+
184+
@classmethod
185+
def ci0(
186+
cls, col_name: str | None = None, /, type: EncodeType = None
187+
) -> dict[str, Any]:
188+
return _parse_aggregate("ci0", col_name, type)
189+
190+
@classmethod
191+
def ci1(
192+
cls, col_name: str | None = None, /, type: EncodeType = None
193+
) -> dict[str, Any]:
194+
return _parse_aggregate("ci1", col_name, type)
195+
196+
@classmethod
197+
def stderr(
198+
cls, col_name: str | None = None, /, type: EncodeType = None
199+
) -> dict[str, Any]:
200+
return _parse_aggregate("stderr", col_name, type)
201+
202+
@classmethod
203+
def stdev(
204+
cls, col_name: str | None = None, /, type: EncodeType = None
205+
) -> dict[str, Any]:
206+
return _parse_aggregate("stdev", col_name, type)
207+
208+
@classmethod
209+
def stdevp(
210+
cls, col_name: str | None = None, /, type: EncodeType = None
211+
) -> dict[str, Any]:
212+
return _parse_aggregate("stdevp", col_name, type)
213+
214+
@classmethod
215+
def sum(
216+
cls, col_name: str | None = None, /, type: EncodeType = None
217+
) -> dict[str, Any]:
218+
return _parse_aggregate("sum", col_name, type)
219+
220+
@classmethod
221+
def valid(
222+
cls, col_name: str | None = None, /, type: EncodeType = None
223+
) -> dict[str, Any]:
224+
return _parse_aggregate("valid", col_name, type)
225+
226+
@classmethod
227+
def values(
228+
cls, col_name: str | None = None, /, type: EncodeType = None
229+
) -> dict[str, Any]:
230+
return _parse_aggregate("values", col_name, type)
231+
232+
@classmethod
233+
def variance(
234+
cls, col_name: str | None = None, /, type: EncodeType = None
235+
) -> dict[str, Any]:
236+
return _parse_aggregate("variance", col_name, type)
237+
238+
@classmethod
239+
def variancep(
240+
cls, col_name: str | None = None, /, type: EncodeType = None
241+
) -> dict[str, Any]:
242+
return _parse_aggregate("variancep", col_name, type)
243+
244+
@classmethod
245+
def exponential(
246+
cls, col_name: str | None = None, /, type: EncodeType = None
247+
) -> dict[str, Any]:
248+
return _parse_aggregate("exponential", col_name, type)
249+
250+
@classmethod
251+
def exponentialb(
252+
cls, col_name: str | None = None, /, type: EncodeType = None
253+
) -> dict[str, Any]:
254+
return _parse_aggregate("exponentialb", col_name, type)
255+
256+
257+
class field:
258+
"""Utility class for field predicates and shorthand parsing.
259+
260+
Examples
261+
--------
262+
>>> field("Origin")
263+
{'field': 'Origin'}
264+
265+
>>> field("Origin:N")
266+
{'field': 'Origin', 'type': 'nominal'}
267+
268+
>>> field.one_of("Origin", "Japan", "Europe")
269+
SelectionPredicateComposition({'field': 'Origin', 'oneOf': ['Japan', 'Europe']})
270+
"""
271+
272+
def __new__( # type: ignore[misc]
273+
cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None
274+
) -> dict[str, Any]:
275+
return _parse(shorthand=shorthand, data=data)
276+
277+
@classmethod
278+
def one_of(
279+
cls,
280+
field: str,
281+
/,
282+
*values: bool | float | dict[str, Any] | SchemaBase,
283+
timeUnit: TimeUnitType = Undefined,
284+
) -> SelectionPredicateComposition:
285+
tp: type[Any] = type(values[0])
286+
if all(isinstance(v, tp) for v in values):
287+
vals: Sequence[Any] = values
288+
p = FieldOneOfPredicate(field=field, oneOf=vals, timeUnit=timeUnit)
289+
return _wrap_composition(p)
290+
else:
291+
msg = (
292+
f"Expected all `values` to be of the same type, but got:\n"
293+
f"{tuple(f"{type(v).__name__}" for v in values)!r}"
294+
)
295+
raise TypeError(msg)
296+
297+
@classmethod
298+
def eq(
299+
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
300+
) -> SelectionPredicateComposition:
301+
p = FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit)
302+
return _wrap_composition(p)
303+
304+
@classmethod
305+
def lt(
306+
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
307+
) -> SelectionPredicateComposition:
308+
p = FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit)
309+
return _wrap_composition(p)
310+
311+
@classmethod
312+
def lte(
313+
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
314+
) -> SelectionPredicateComposition:
315+
p = FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit)
316+
return _wrap_composition(p)
317+
318+
@classmethod
319+
def gt(
320+
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
321+
) -> SelectionPredicateComposition:
322+
p = FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit)
323+
return _wrap_composition(p)
324+
325+
@classmethod
326+
def gte(
327+
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
328+
) -> SelectionPredicateComposition:
329+
p = FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit)
330+
return _wrap_composition(p)
331+
332+
@classmethod
333+
def valid(
334+
cls, field: str, value: bool, /, *, timeUnit: TimeUnitType = Undefined
335+
) -> SelectionPredicateComposition:
336+
p = FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit)
337+
return _wrap_composition(p)
338+
339+
@classmethod
340+
def range(
341+
cls, field: str, value: RangeType, /, *, timeUnit: TimeUnitType = Undefined
342+
) -> SelectionPredicateComposition:
343+
p = FieldRangePredicate(field=field, range=value, timeUnit=timeUnit)
344+
return _wrap_composition(p)

0 commit comments

Comments
 (0)