@@ -5,7 +5,7 @@ import numpy as np
55import numpy .typing as npt
66import optype as op
77import optype .numpy as onp
8- from numpy ._typing import _DTypeLike
8+ from numpy ._typing import _ArrayLike , _DTypeLike , _NestedSequence
99from scipy ._typing import AnyShape
1010
1111__all__ = ["chirp" , "gausspulse" , "sawtooth" , "square" , "sweep_poly" , "unit_impulse" ]
@@ -18,6 +18,16 @@ _ArrayLikeFloat: TypeAlias = onp.ToFloat | onp.ToFloatND
1818_Array_f8 : TypeAlias = onp .ArrayND [np .float64 ]
1919_GaussPulseTime : TypeAlias = _ArrayLikeFloat | Literal ["cutoff" ]
2020
21+ # Type vars to annotate `chirp`
22+ _NBT1 = TypeVar ("_NBT1" , bound = npt .NBitBase )
23+ _NBT2 = TypeVar ("_NBT2" , bound = npt .NBitBase )
24+ _NBT3 = TypeVar ("_NBT3" , bound = npt .NBitBase )
25+ _NBT4 = TypeVar ("_NBT4" , bound = npt .NBitBase )
26+ _NBT5 = TypeVar ("_NBT5" , bound = npt .NBitBase )
27+ _ChirpTime : TypeAlias = _ArrayLike [np .floating [_NBT1 ] | np .integer [_NBT1 ]]
28+ _ChirpScalar : TypeAlias = float | np .floating [_NBT1 ] | np .integer [_NBT1 ]
29+ _ChirpMethod : TypeAlias = Literal ["linear" , "quadratic" , "logarithmic" , "hyperbolic" ]
30+
2131def sawtooth (t : _ArrayLikeFloat , width : _ArrayLikeFloat = 1 ) -> _Array_f8 : ...
2232def square (t : _ArrayLikeFloat , duty : _ArrayLikeFloat = 0.5 ) -> _Array_f8 : ...
2333
@@ -96,16 +106,29 @@ def gausspulse(
96106 retenv : _Truthy ,
97107) -> tuple [_Array_f8 , _Array_f8 , _Array_f8 ]: ...
98108
99- # float16 -> float16, float32 -> float32, ... -> float64
109+ #
110+ @overload # Static type checking for float values
111+ def chirp (
112+ t : _ChirpTime [_NBT1 ],
113+ f0 : _ChirpScalar [_NBT2 ],
114+ t1 : _ChirpScalar [_NBT3 ],
115+ f1 : _ChirpScalar [_NBT4 ],
116+ method : _ChirpMethod = "linear" ,
117+ phi : _ChirpScalar [_NBT5 ] = 0 ,
118+ vertex_zero : op .CanBool = True ,
119+ ) -> onp .ArrayND [np .floating [_NBT1 | _NBT2 | _NBT3 | _NBT4 | _NBT5 ]]: ...
120+ @overload # Other dtypes default to np.float64
100121def chirp (
101- t : onp .ToFloatND ,
122+ t : onp .ToFloatND | _NestedSequence [ float ] ,
102123 f0 : onp .ToFloat ,
103124 t1 : onp .ToFloat ,
104125 f1 : onp .ToFloat ,
105- method : Literal [ "linear" , "quadratic" , "logarithmic" , "hyperbolic" ] = "linear" ,
126+ method : _ChirpMethod = "linear" ,
106127 phi : onp .ToFloat = 0 ,
107128 vertex_zero : op .CanBool = True ,
108- ) -> npt .NDArray [np .float16 | np .float32 | np .float64 ]: ...
129+ ) -> _Array_f8 : ...
130+
131+ #
109132def sweep_poly (
110133 t : _ArrayLikeFloat ,
111134 poly : onp .ToFloatND | np .poly1d ,
0 commit comments