Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
@@ -271,6 +271,44 @@ def local_func_inv(fgraph, node):
return


@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
def local_func_inv_nan_switch(fgraph, node):
"""
Check for two consecutive switch operations that are functional inverses
and remove them from the function graph.
"""
inv_pairs = (
(aes_math.Log1mexp, aes_math.Log1mexp),
(aes.Log, aes.Exp),
(aes.Exp, aes.Log),
)
x = node.inputs[0]

if not isinstance(node.op, Elemwise):
return
if not x.owner or not isinstance(x.owner.op, Elemwise):
return

prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op

for inv_pair in inv_pairs:
if is_inverse_pair(node_op, prev_op, inv_pair):
# We don't need to copy stack trace, because the rewrite
# is trivial and maintains the earlier stack trace
ottype = node.out.dtype
inp = x.owner.inputs[0]
# Functions may have casted integer input to float
if inp.dtype != ottype:
inp = cast(inp, ottype)
return [inp]

return


@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
@@ -363,6 +401,15 @@ def local_exp_log_nan_switch(fgraph, node):
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out]

# Case for log1mexp(log1mexp(x)) -> x
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(
Copy link
Member

@ricardoV94 ricardoV94 Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes more sense as a separate rewrite, analogous to the local_func_inv, something like local_func_inv_nan_switch. Otherwise looks nice. Would also need a test

node_op, aes_math.Log1mexp
):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), x, np.asarray(np.nan, old_out.dtype))
return [new_out]

# Case for expm1(log1mexp(x)) -> -exp(x)
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
3 changes: 2 additions & 1 deletion tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
@@ -1773,7 +1773,7 @@ def test_local_pow_to_nested_squaring_fails_gracefully():
class TestFuncInverse:
def setup_method(self):
mode = get_default_mode()
self.mode = mode.including("local_func_inv")
self.mode = mode.including("local_func_inv", "local_func_inv_nan_switch")

def assert_func_pair_rewritten(
self, func1, func2, data, should_copy=True, is_complex=False
@@ -1817,6 +1817,7 @@ def test(self):
self.assert_func_pair_rewritten(arcsinh, sinh, dx)
self.assert_func_pair_rewritten(arctanh, tanh, dx)
self.assert_func_pair_rewritten(reciprocal, reciprocal, dx)
self.assert_func_pair_rewritten(log1mexp, log1mexp, dx)
self.assert_func_pair_rewritten(neg, neg, dx)
cx = dx + complex(0, 1) * (dx + 0.01)
self.assert_func_pair_rewritten(conj, conj, cx, is_complex=True)