Skip to content

Commit 43c11f8

Browse files
authored
Add PyTensor backend (#362)
1 parent 253545a commit 43c11f8

File tree

7 files changed

+68
-4
lines changed

7 files changed

+68
-4
lines changed

.github/workflows/run_tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ jobs:
1313
# cupy is not tested because it demands gpu
1414
# oneflow testing is dropped, see details at https://github.com/Oneflow-Inc/oneflow/issues/10340
1515
# paddle was switched off because of divergence with numpy in py3.10, paddle==2.6.1
16-
frameworks: ['numpy pytorch tensorflow jax']
16+
# The last pytensor release that supports python 3.8 doesn't include einsum, so we skip that combination.
17+
frameworks: ['numpy pytorch tensorflow jax', 'pytensor']
18+
exclude:
19+
- python-version: '3.8'
20+
frameworks: 'pytensor'
1721

1822
steps:
1923
- uses: actions/checkout@v4

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ Einops works with ...
307307
- [paddle](https://github.com/PaddlePaddle/Paddle) (community)
308308
- [oneflow](https://github.com/Oneflow-Inc/oneflow) (community)
309309
- [tinygrad](https://github.com/tinygrad/tinygrad) (community)
310+
- [pytensor](https://github.com/pymc-devs/pytensor) (community)
310311

311312
Additionally, einops can be used with any framework that supports
312313
[Python array API standard](https://data-apis.org/array-api/latest/API_specification/index.html),

einops/_backends.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,58 @@ def is_float_type(self, x):
662662

663663
def einsum(self, pattern, *x):
664664
return self.tinygrad.Tensor.einsum(pattern, *x)
665+
666+
667+
class PyTensorBackend(AbstractBackend):
668+
framework_name = "pytensor"
669+
670+
def __init__(self):
671+
from pytensor import tensor
672+
673+
self.pt = tensor
674+
675+
def is_appropriate_type(self, tensor):
676+
return isinstance(tensor, self.pt.TensorVariable)
677+
678+
def is_float_type(self, x):
679+
return x.dtype in self.pt.type.float_dtypes
680+
681+
def from_numpy(self, x):
682+
return self.pt.as_tensor(x)
683+
684+
def to_numpy(self, x):
685+
return x.eval() # Will only work if there are no symbolic inputs
686+
687+
def create_symbol(self, shape):
688+
if not isinstance(shape, tuple | list):
689+
shape = (shape,)
690+
return self.pt.tensor(shape=shape)
691+
692+
def eval_symbol(self, symbol, input_dict):
693+
# input_dict is actually a list of tuple?
694+
return symbol.eval(dict(input_dict))
695+
696+
def arange(self, start, stop):
697+
return self.pt.arange(start, stop)
698+
699+
def shape(self, x):
700+
# use the static shape dimensions where known
701+
return tuple(
702+
static_dim if static_dim is not None else symbolic_dim
703+
for static_dim, symbolic_dim in zip(x.type.shape, x.shape)
704+
)
705+
706+
def stack_on_zeroth_dimension(self, tensors: list):
707+
return self.pt.stack(tensors)
708+
709+
def tile(self, x, repeats):
710+
return self.pt.tile(x, repeats)
711+
712+
def concat(self, tensors, axis: int):
713+
return self.pt.concatenate(tensors, axis=axis)
714+
715+
def add_axis(self, x, new_position):
716+
return self.pt.expand_dims(x, new_position)
717+
718+
def einsum(self, pattern, *x):
719+
return self.pt.einsum(pattern, *x)

einops/tests/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ def collect_test_backends(symbolic=False, layers=False) -> List[_backends.Abstra
8787
]
8888
else:
8989
if not layers:
90-
backend_types = []
90+
backend_types = [
91+
_backends.PyTensorBackend,
92+
]
9193
else:
9294
backend_types = [
9395
_backends.TFKerasBackend,

einops/tests/run_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def main():
3333
# "paddle": ["paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html"],
3434
"paddle": ["paddlepaddle"],
3535
"oneflow": ["oneflow==0.9.0"],
36+
"pytensor": ["pytensor"],
3637
}
3738

3839
usage = f"""

einops/tests/test_einsum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def test_layer():
187187
"cupy",
188188
"tensorflow.keras",
189189
"paddle",
190+
"pytensor",
190191
]
191192

192193

@@ -254,7 +255,7 @@ def test_functional_symbolic():
254255
)
255256
if predicted_out_data.shape != out_shape:
256257
raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
257-
assert np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
258+
np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
258259

259260

260261
def test_functional_errors():

einops/tests/test_other.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def test_parse_shape_symbolic_ellipsis(backend):
210210
for static_shape, shape, pattern, expected in [
211211
([10, 20], [None, None], "...", dict()),
212212
([10], [None], "... a", dict(a=10)),
213-
([10, 20], [None], "... a", dict(a=20)),
213+
([10, 20], [None, None], "... a", dict(a=20)),
214214
([10, 20, 30], [None, None, None], "... a", dict(a=30)),
215215
([10, 20, 30, 40], [None, None, None, None], "... a", dict(a=40)),
216216
([10], [None], "a ...", dict(a=10)),

0 commit comments

Comments
 (0)