Closed
Description
Describe the issue:
After #303 I was thinking it would be nice to use the new `lower_triangular` and `upper_triangular` tags to rewrite inverses and solves involving triangular matrices to use `solve_triangular`. I saw this optimization already exists as `tag_solve_triangular`. For example, this graph:
```python
import pytensor
import pytensor.tensor as pt
import numpy as np

A = pt.dmatrix('A')
b = pt.dmatrix('b')

L = pt.linalg.cholesky(A)
X = pt.linalg.solve(L, b)
f = pytensor.function([A, b], [X])

pytensor.dprint(X)
```

```
Solve{assume_a='gen', lower=False, check_finite=True} [id A]
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B]
 │  └─ A [id C]
 └─ b [id D]
```
Gets rewritten to:
```python
pytensor.dprint(f)
```

```
Solve{assume_a='sym', lower=True, check_finite=True} [id A] 1
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B] 0
 │  └─ A [id C]
 └─ b [id D]
```
But as I point out in #291, `Solve(assume_a='sym', lower=True)(A, b)` is not the same as `solve_triangular(A, b, lower=True)`. Indeed, a lot of speed is being left on the table:
```python
Z = np.random.normal(size=(5000, 5000))
P = Z @ Z.T
P_chol = np.linalg.cholesky(P)
eye = np.eye(5000)

%timeit f(P, eye)
>>> 3.69 s ± 48.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

```python
from pytensor.tensor.slinalg import solve_triangular

X2 = solve_triangular(L, b, lower=True)
f2 = pytensor.function([A, b], [X2])

%timeit f2(P, eye)
>>> 1.36 s ± 15.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
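For reference, the same kind of gap can be reproduced with SciPy alone, independent of PyTensor (a minimal sketch with a smaller, made-up matrix; exact timings are machine-dependent):

```python
import time
import numpy as np
from scipy.linalg import solve, solve_triangular

n = 1000
rng = np.random.default_rng(0)
Z = rng.normal(size=(n, n))
# Lower-triangular Cholesky factor of a (diagonally boosted) random SPD matrix
L = np.linalg.cholesky(Z @ Z.T + n * np.eye(n))
eye = np.eye(n)

t0 = time.perf_counter()
x_sym = solve(L, eye, assume_a='sym', lower=True)  # symmetric LDL^T path
t_sym = time.perf_counter() - t0

t0 = time.perf_counter()
x_tri = solve_triangular(L, eye, lower=True)  # direct triangular back-substitution
t_tri = time.perf_counter() - t0

print(f"assume_a='sym': {t_sym:.3f} s, solve_triangular: {t_tri:.3f} s")
```

The triangular path skips the factorization step entirely, which is where the savings come from.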
And it seems like something is going wrong with the `Solve()` approach, because the following test fails:
```python
from numpy.testing import assert_allclose

x1 = f(P, eye)[0]
x2 = f2(P, eye)[0]

assert_allclose(x1 @ P_chol, eye, atol=1e-8)  # fails
assert_allclose(x2 @ P_chol, eye, atol=1e-8)  # passes
```
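One plausible explanation, sketched here with SciPy directly rather than PyTensor (small made-up inputs, not the arrays from above): with `assume_a='sym'`, LAPACK reads only the triangle selected by `lower` and treats the matrix as symmetric, so passing a Cholesky factor solves a *different* (symmetrized) system than the true triangular solve. If the rewritten `Solve` Op inherits this SciPy behavior, that would account for the failing assertion.

```python
import numpy as np
from scipy.linalg import solve, solve_triangular

rng = np.random.default_rng(0)
L = np.tril(rng.normal(size=(3, 3))) + 3 * np.eye(3)  # lower-triangular system
b = rng.normal(size=(3, 1))

# assume_a='sym' with lower=True reads only the lower triangle and implicitly
# solves with the symmetric matrix L + tril(L, -1).T, not with L itself.
x_sym = solve(L, b, assume_a='sym', lower=True)
x_tri = solve_triangular(L, b, lower=True)

print(np.allclose(L @ x_tri, b))  # genuine triangular solve: residual vanishes
print(np.allclose(L @ x_sym, b))  # solved the symmetrized system instead
```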
Am I missing something with all this?
Reproducible code example:
See above
Error message:
No response
PyTensor version information:
Pytensor version: 2.12.3
Context for the issue:
No response