BUG: tag_solve_triangular doesn't use a triangular solver #382

Closed
@jessegrabowski

Description

Describe the issue:

After #303, I was thinking it would be nice to use the new lower_triangular and upper_triangular tags to rewrite inverses and solves involving triangular matrices so that they use solve_triangular. I saw that this optimization already exists as tag_solve_triangular. For example, this graph:

import pytensor
import pytensor.tensor as pt
import numpy as np

A = pt.dmatrix('A')
b = pt.dmatrix('b')
L = pt.linalg.cholesky(A)
X = pt.linalg.solve(L, b)
f = pytensor.function([A, b], [X])
pytensor.dprint(X)

Solve{assume_a='gen', lower=False, check_finite=True} [id A]
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B]
 │  └─ A [id C]
 └─ b [id D]

Gets rewritten to:

pytensor.dprint(f)
Solve{assume_a='sym', lower=True, check_finite=True} [id A] 1
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B] 0
 │  └─ A [id C]
 └─ b [id D]

But as I point out in #291, Solve(assume_a='sym', lower=True)(A, b) is not the same as solve_triangular(A, b, lower=True). Indeed, a lot of speed is being left on the table:

Z = np.random.normal(size=(5000, 5000))
P = Z @ Z.T
P_chol = np.linalg.cholesky(P)
eye = np.eye(5000)
%timeit f(P, eye)
>>> 3.69 s ± 48.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

from pytensor.tensor.slinalg import solve_triangular
X2 = solve_triangular(L, b, lower=True)
f2 = pytensor.function([A, b], [X2])
%timeit f2(P, eye)
>>> 1.36 s ± 15.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

And it seems like something is going wrong with the Solve() approach, because the following test fails:

from numpy.testing import assert_allclose
x1 = f(P, eye)[0]
x2 = f2(P, eye)[0]
assert_allclose(x1 @ P_chol, eye, atol=1e-8) # fails
assert_allclose(x2 @ P_chol, eye, atol=1e-8) # passes
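
A possible explanation for the failing check, sketched with SciPy directly rather than PyTensor: as far as I understand, scipy.linalg.solve with assume_a='sym' and lower=True reads only the lower triangle and treats the input as the symmetric matrix that triangle implies. Passing a Cholesky factor L would then solve against L + L.T - diag(L) instead of L. The small matrix below is made up for illustration, and whether PyTensor's Solve forwards these arguments to SciPy in exactly this way is an assumption.

```python
import numpy as np
from scipy.linalg import solve, solve_triangular

# Hypothetical 3x3 lower-triangular matrix standing in for a Cholesky factor.
L = np.array([[4.0, 0.0, 0.0],
              [1.0, 5.0, 0.0],
              [2.0, 3.0, 6.0]])
b = np.eye(3)

# Symmetric matrix implied by reading only the lower triangle of L.
S = L + L.T - np.diag(np.diag(L))

x_sym = solve(L, b, assume_a="sym", lower=True)  # symmetric solver
x_tri = solve_triangular(L, b, lower=True)       # triangular solver

# Check which system each result actually solves.
sym_solves_S = np.allclose(S @ x_sym, b)
tri_solves_L = np.allclose(L @ x_tri, b)
sym_solves_L = np.allclose(L @ x_sym, b)
print(sym_solves_S, tri_solves_L, sym_solves_L)
```

If the same thing happens inside the rewritten graph, x1 would be the inverse of a different matrix than P_chol, which would be consistent with the first assertion failing.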

Am I missing something with all this?

Reproducible code example:

See above

Error message:

No response

PyTensor version information:

Pytensor version: 2.12.3

Context for the issue:

No response
