Skip to content

Commit d1495a7

Browse files
committed
[471] adding case for log1mexp(log1mexp(x)) -> x
Signed-off-by: Nathaniel <[email protected]>
1 parent 081a0b4 commit d1495a7

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,13 @@ def local_exp_log_nan_switch(fgraph, node):
362362
old_out = node.outputs[0]
363363
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
364364
return [new_out]
365+
366+
# Case for log1mexp(log1mexp(x)) -> x
367+
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes_math.Log1mexp):
368+
x = x.owner.inputs[0]
369+
old_out = node.outputs[0]
370+
new_out = switch(le(x, 0), x, np.asarray(np.nan, old_out.dtype))
371+
return [new_out]
365372

366373
# Case for expm1(log1mexp(x)) -> -exp(x)
367374
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1):

0 commit comments

Comments
 (0)