
Commit 34230a2: Fuse hyp2f1 grads
Committed May 14, 2023
Parent: 9d31e8f

3 files changed: 285 additions, 135 deletions

pytensor/scalar/math.py
Lines changed: 188 additions & 135 deletions
@@ -5,7 +5,9 @@
 """

 import os
+from functools import reduce
 from textwrap import dedent
+from typing import Tuple

 import numpy as np
 import scipy.special
@@ -683,23 +685,28 @@ def __hash__(self):
 gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")


-def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
-    init = [as_scalar(x) for x in init]
+def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
+    init = [as_scalar(x) if x is not None else None for x in init]
     constant = [as_scalar(x) for x in constant]
+
     # Create dummy types, in case some variables have the same initial form
-    init_ = [x.type() for x in init]
+    init_ = [x.type() if x is not None else None for x in init]
     constant_ = [x.type() for x in constant]
     update_, until_ = inner_loop_fn(*init_, *constant_)
-    op = ScalarLoop(
+
+    # Filter Nones
+    init = [i for i in init if i is not None]
+    init_ = [i for i in init_ if i is not None]
+    update_ = [u for u in update_ if u is not None]
+    op = loop_op(
         init=init_,
         constant=constant_,
         update=update_,
         until=until_,
         until_condition_failed="warn",
         name=name,
     )
-    S, *_ = op(n_steps, *init, *constant)
-    return S
+    return op(n_steps, *init, *constant)


 def gammainc_grad(k, x):
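
Editor's note (not part of the commit): the helper above is internal, but its new convention is worth spelling out. `init` may now contain None placeholders for carried values the caller does not need; the inner-loop function returns None in the matching update slots, the placeholders are filtered out before the ScalarLoop Op is built, and the helper now returns every loop output rather than only the first (hence the `x, *_ = ...` unpacking at the call sites below). A runnable pure-Python toy of the convention, with all names hypothetical:

# Toy model of the None-placeholder convention (illustrative only; the real
# helper builds a symbolic ScalarLoop Op instead of iterating eagerly).
def make_loop(n_steps, init, constant, inner_loop_fn):
    state = list(init)  # may contain None placeholders
    for _ in range(n_steps):
        update, until = inner_loop_fn(*state, *constant)
        state = list(update)
        if until:
            break
    # Drop the None slots, as _make_scalar_loop now does, and return
    # every remaining output, not just the first one.
    return tuple(s for s in state if s is not None)

def inner(acc, unused, x):
    # `unused` is the None placeholder; it stays None in the update tuple.
    new_acc = acc + x
    return (new_acc, None), new_acc >= 6

(acc,) = make_loop(4, init=[0, None], constant=[2], inner_loop_fn=inner)
assert acc == 6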
@@ -740,7 +747,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):

     init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
     constant = [log_x]
-    sum_a = _make_scalar_loop(
+    sum_a, *_ = _make_scalar_loop(
         max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
     )

@@ -827,7 +834,7 @@ def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):

     init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
     constant = [x]
-    sum_a = _make_scalar_loop(
+    sum_a, *_ = _make_scalar_loop(
         n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
     )
     grad_approx_a = (
@@ -870,7 +877,7 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):

     init = [sum_b0, log_s, s_sign, log_delta, n]
     constant = [k, log_x]
-    sum_b = _make_scalar_loop(
+    sum_b, *_ = _make_scalar_loop(
         max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
     )
     grad_approx_b = (
@@ -1540,7 +1547,7 @@ def inner_loop(

     init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
     constant = [f, p, q, K, dK]
-    grad = _make_scalar_loop(
+    grad, *_ = _make_scalar_loop(
         max_iters, init, constant, inner_loop, name="betainc_grad"
     )
     return grad
@@ -1579,10 +1586,11 @@ def impl(self, a, b, c, z):
     def grad(self, inputs, grads):
         a, b, c, z = inputs
         (gz,) = grads
+        grad_a, grad_b, grad_c = hyp2f1_grad(a, b, c, z, wrt=[0, 1, 2])
         return [
-            gz * hyp2f1_grad(a, b, c, z, wrt=0),
-            gz * hyp2f1_grad(a, b, c, z, wrt=1),
-            gz * hyp2f1_grad(a, b, c, z, wrt=2),
+            gz * grad_a,
+            gz * grad_b,
+            gz * grad_c,
             gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
         ]

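Editor's note (illustrative, not from the diff): the visible effect of this hunk is that a single fused loop now produces all three parameter gradients, where the previous code built three independent loops over the same series. A sketch using the public API exercised by the tests below:

import pytensor
import pytensor.tensor as at

# One fused Grad2F1Loop now computes d/da, d/db and d/dc together,
# instead of one ScalarLoop per parameter as before this commit.
a, b, c, z = at.scalars("a", "b", "c", "z")
out = at.hyp2f1(a, b, c, z)
grad_a, grad_b, grad_c = pytensor.grad(out, wrt=[a, b, c])
f = pytensor.function([a, b, c, z], [grad_a, grad_b, grad_c])
print(f(1.5, 2.0, 3.0, 0.5))
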
@@ -1598,92 +1606,55 @@ def _unsafe_sign(x):
     return switch(x > 0, 1, -1)


-def hyp2f1_grad(a, b, c, z, wrt: int):
-    dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
-
-    def check_2f1_converges(a, b, c, z):
-        def is_nonpositive_integer(x):
-            if x.type.dtype not in integer_types:
-                return eq(floor(x), x) & (x <= 0)
-            else:
-                return x <= 0
+class Grad2F1Loop(ScalarLoop):
+    """Subclass of ScalarLoop for easier targetting in rewrites"""

-        a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
-        num_terms = switch(
-            a_is_polynomial,
-            floor(scalar_abs(a)).astype("int64"),
-            0,
-        )

-        b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
-        num_terms = switch(
-            b_is_polynomial,
-            floor(scalar_abs(b)).astype("int64"),
-            num_terms,
-        )
+def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
+    """
+    Notes
+    -----
+    The algorithm can be derived by looking at the ratio of two successive terms in the series
+    β_{k+1}/β_{k} = A(k)/B(k)
+    β_{k+1} = A(k)/B(k) * β_{k}
+    d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

-        is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
-        is_polynomial = a_is_polynomial | b_is_polynomial
+    In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z

-        return (~is_undefined) & (
-            is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
-        )
+    The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
+    by dropping the respective term
+    d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
+    d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
+    d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)

-    def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
-        """
-        Notes
-        -----
-        The algorithm can be derived by looking at the ratio of two successive terms in the series
-        β_{k+1}/β_{k} = A(k)/B(k)
-        β_{k+1} = A(k)/B(k) * β_{k}
-        d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
-
-        In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
-
-        The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
-        by dropping the respective term
-        d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
-        d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
-        d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
-
-        The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
-        tracking their signs.
-        """
-
-        wrt_a = wrt_b = False
-        if wrt == 0:
-            wrt_a = True
-        elif wrt == 1:
-            wrt_b = True
-        elif wrt != 2:
-            raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
-
-        min_steps = np.array(
-            10, dtype="int32"
-        )  # https://github.com/stan-dev/math/issues/2857
-        max_steps = switch(
-            skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
-        )
-        precision = np.array(1e-14, dtype=config.floatX)
+    The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
+    tracking their signs.
+    """

-        grad = np.array(0, dtype=dtype)
+    min_steps = np.array(
+        10, dtype="int32"
+    )  # https://github.com/stan-dev/math/issues/2857
+    max_steps = switch(
+        skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
+    )
+    precision = np.array(1e-14, dtype=config.floatX)

-        log_g = np.array(-np.inf, dtype=dtype)
-        log_g_sign = np.array(1, dtype="int8")
+    grads = [np.array(0, dtype=dtype) if i in wrt else None for i in range(3)]
+    log_gs = [np.array(-np.inf, dtype=dtype) if i in wrt else None for i in range(3)]
+    log_gs_signs = [np.array(1, dtype="int8") if i in wrt else None for i in range(3)]

-        log_t = np.array(0.0, dtype=dtype)
-        log_t_sign = np.array(1, dtype="int8")
+    log_t = np.array(0.0, dtype=dtype)
+    log_t_sign = np.array(1, dtype="int8")

-        log_z = log(scalar_abs(z))
-        sign_z = _unsafe_sign(z)
+    log_z = log(scalar_abs(z))
+    sign_z = _unsafe_sign(z)

-        sign_zk = sign_z
-        k = np.array(0, dtype="int32")
+    sign_zk = sign_z
+    k = np.array(0, dtype="int32")

-        def inner_loop(
-            grad,
-            log_g,
-            log_g_sign,
+    def inner_loop(*args):
+        (
+            *grads_vars,
             log_t,
             log_t_sign,
             sign_zk,
@@ -1693,65 +1664,147 @@ def inner_loop(
             c,
             log_z,
             sign_z,
-        ):
-            p = (a + k) * (b + k) / ((c + k) * (k + 1))
-            if p.type.dtype != dtype:
-                p = p.astype(dtype)
-
-            term = log_g_sign * log_t_sign * exp(log_g - log_t)
-            if wrt_a:
-                term += reciprocal(a + k)
-            elif wrt_b:
-                term += reciprocal(b + k)
-            else:
-                term -= reciprocal(c + k)
+        ) = args
+
+        (
+            grad_a,
+            grad_b,
+            grad_c,
+            log_g_a,
+            log_g_b,
+            log_g_c,
+            log_g_sign_a,
+            log_g_sign_b,
+            log_g_sign_c,
+        ) = grads_vars
+
+        p = (a + k) * (b + k) / ((c + k) * (k + 1))
+        if p.type.dtype != dtype:
+            p = p.astype(dtype)
+
+        # If p==0, don't update grad and get out of while loop next
+        p_zero = eq(p, 0)
+
+        if 0 in wrt:
+            term_a = log_g_sign_a * log_t_sign * exp(log_g_a - log_t)
+            term_a += reciprocal(a + k)
+            if term_a.type.dtype != dtype:
+                term_a = term_a.astype(dtype)
+        if 1 in wrt:
+            term_b = log_g_sign_b * log_t_sign * exp(log_g_b - log_t)
+            term_b += reciprocal(b + k)
+            if term_b.type.dtype != dtype:
+                term_b = term_b.astype(dtype)
+        if 2 in wrt:
+            term_c = log_g_sign_c * log_t_sign * exp(log_g_c - log_t)
+            term_c -= reciprocal(c + k)
+            if term_c.type.dtype != dtype:
+                term_c = term_c.astype(dtype)
+
+        log_t = log_t + log(scalar_abs(p)) + log_z
+        log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
+
+        grads = [None] * 3
+        log_gs = [None] * 3
+        log_gs_signs = [None] * 3
+        grad_incs = [None] * 3
+
+        if 0 in wrt:
+            log_g_a = log_t + log(scalar_abs(term_a))
+            log_g_sign_a = (_unsafe_sign(term_a) * log_t_sign).astype("int8")
+            grad_inc_a = log_g_sign_a * exp(log_g_a) * sign_zk
+            grads[0] = switch(p_zero, grad_a, grad_a + grad_inc_a)
+            log_gs[0] = log_g_a
+            log_gs_signs[0] = log_g_sign_a
+            grad_incs[0] = grad_inc_a
+        if 1 in wrt:
+            log_g_b = log_t + log(scalar_abs(term_b))
+            log_g_sign_b = (_unsafe_sign(term_b) * log_t_sign).astype("int8")
+            grad_inc_b = log_g_sign_b * exp(log_g_b) * sign_zk
+            grads[1] = switch(p_zero, grad_b, grad_b + grad_inc_b)
+            log_gs[1] = log_g_b
+            log_gs_signs[1] = log_g_sign_b
+            grad_incs[1] = grad_inc_b
+        if 2 in wrt:
+            log_g_c = log_t + log(scalar_abs(term_c))
+            log_g_sign_c = (_unsafe_sign(term_c) * log_t_sign).astype("int8")
+            grad_inc_c = log_g_sign_c * exp(log_g_c) * sign_zk
+            grads[2] = switch(p_zero, grad_c, grad_c + grad_inc_c)
+            log_gs[2] = log_g_c
+            log_gs_signs[2] = log_g_sign_c
+            grad_incs[2] = grad_inc_c
+
+        sign_zk *= sign_z
+        k += 1
+
+        abs_grad_incs = [
+            scalar_abs(grad_inc) for grad_inc in grad_incs if grad_inc is not None
+        ]
+        if len(grad_incs) == 1:
+            [max_abs_grad_inc] = grad_incs
+        else:
+            max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)

-            if term.type.dtype != dtype:
-                term = term.astype(dtype)
+        return (
+            (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
+            (eq(p, 0) | ((k > min_steps) & (max_abs_grad_inc <= precision))),
+        )

-            log_t = log_t + log(scalar_abs(p)) + log_z
-            log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
-            log_g = log_t + log(scalar_abs(term))
-            log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
+    init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
+    constant = [a, b, c, log_z, sign_z]
+    loop_outs = _make_scalar_loop(
+        max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop
+    )
+    return loop_outs[: len(wrt)]

-            g_current = log_g_sign * exp(log_g) * sign_zk

-            # If p==0, don't update grad and get out of while loop next
-            grad = switch(
-                eq(p, 0),
-                grad,
-                grad + g_current,
-            )
+def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
+    dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")

-            sign_zk *= sign_z
-            k += 1
+    def check_2f1_converges(a, b, c, z):
+        def is_nonpositive_integer(x):
+            if x.type.dtype not in integer_types:
+                return eq(floor(x), x) & (x <= 0)
+            else:
+                return x <= 0

-            return (
-                (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
-                (eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))),
-            )
+        a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
+        num_terms = switch(
+            a_is_polynomial,
+            floor(scalar_abs(a)).astype("int64"),
+            0,
+        )

-        init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
-        constant = [a, b, c, log_z, sign_z]
-        grad = _make_scalar_loop(
-            max_steps, init, constant, inner_loop, name="hyp2f1_grad"
+        b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
+        num_terms = switch(
+            b_is_polynomial,
+            floor(scalar_abs(b)).astype("int64"),
+            num_terms,
         )

-        return switch(
-            eq(z, 0),
-            0,
-            grad,
+        is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
+        is_polynomial = a_is_polynomial | b_is_polynomial
+
+        return (~is_undefined) & (
+            is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
        )

     # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
     z_is_zero = eq(z, 0)
     converges = check_2f1_converges(a, b, c, z)
-    return switch(
-        z_is_zero,
-        0,
-        switch(
-            converges,
-            compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
-            np.nan,
-        ),
+    grads = _grad_2f1_loop(
+        a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype
     )
+
+    return [
+        switch(
+            z_is_zero,
+            0,
+            switch(
+                converges,
+                grad,
+                np.nan,
+            ),
+        )
+        for grad in grads
+    ]
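
Editor's note: in standard notation, the recurrence in the docstring above is the term ratio of the Gauss hypergeometric series. A cleaned-up statement (with the sign on the c-derivative made explicit, matching `term_c -= reciprocal(c + k)` in the code):

\[
{}_2F_1(a, b; c; z) = \sum_{k=0}^{\infty} \beta_k,
\qquad
\beta_k = \frac{(a)_k \, (b)_k}{(c)_k} \, \frac{z^k}{k!},
\qquad
\frac{\beta_{k+1}}{\beta_k} = \frac{(a+k)(b+k)}{(c+k)(1+k)} \, z =: \frac{A(k)}{B(k)}
\]
\[
d\beta_{k+1} = d\!\left[\frac{A(k)}{B(k)}\right] \beta_k + \frac{A(k)}{B(k)} \, d\beta_k,
\qquad
\frac{\partial}{\partial a}\frac{A(k)}{B(k)} = \frac{A(k)/B(k)}{a+k},
\quad
\frac{\partial}{\partial b}\frac{A(k)}{B(k)} = \frac{A(k)/B(k)}{b+k},
\quad
\frac{\partial}{\partial c}\frac{A(k)}{B(k)} = -\frac{A(k)/B(k)}{c+k}
\]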

pytensor/tensor/rewriting/elemwise.py
Lines changed: 60 additions & 0 deletions
@@ -23,6 +23,7 @@
 from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
 from pytensor.scalar.loop import ScalarLoop
+from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
     MakeVector,
     alloc,
@@ -31,6 +32,7 @@
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.math import exp
 from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
 from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.var import TensorConstant
@@ -1215,3 +1217,61 @@ def local_careduce_fusion(fgraph, node):
     "fusion",
     position=49,
 )
+
+
+@register_specialize
+@node_rewriter([Elemwise])
+def local_useless_2f1grad_loop(fgraph, node):
+    # Remove unused terms from the hyp2f1 grad loop
+
+    loop_op = node.op.scalar_op
+    if not isinstance(loop_op, Grad2F1Loop):
+        return
+
+    grad_related_vars = node.outputs[:-4]
+    # Rewrite was already applied
+    if len(grad_related_vars) // 3 != 3:
+        return None
+
+    grad_vars = grad_related_vars[:3]
+    grad_var_is_used = [bool(fgraph.clients.get(v)) for v in grad_vars]
+
+    # Nothing to do here
+    if sum(grad_var_is_used) == 3:
+        return None
+
+    # Check that None of the remaining vars is used anywhere
+    if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]):
+        return None
+
+    a, b, c, log_z, sign_z = node.inputs[-5:]
+    z = exp(log_z) * sign_z
+
+    # Reconstruct scalar loop with relevant outputs
+    a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
+    wrt = [i for i, used in enumerate(grad_var_is_used) if used]
+    new_loop_op = _grad_2f1_loop(
+        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
+    )[0].owner.op
+
+    # Reconstruct elemwise loop
+    new_elemwise_op = Elemwise(scalar_op=new_loop_op)
+    n_steps = node.inputs[0]
+    init_grad_vars = node.inputs[1:10]
+    other_inputs = node.inputs[10:]
+
+    init_grads = init_grad_vars[: len(wrt)]
+    init_gs = init_grad_vars[3 : 3 + len(wrt)]
+    init_gs_signs = init_grad_vars[6 : 6 + len(wrt)]
+    subset_init_grad_vars = init_grads + init_gs + init_gs_signs
+
+    new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)
+
+    replacements = {}
+    i = 0
+    for grad_var, is_used in zip(grad_vars, grad_var_is_used):
+        if not is_used:
+            continue
+        replacements[grad_var] = new_outs[i]
+        i += 1
+    return replacements
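
Editor's note (sketch of what the rewrite buys, not from the diff): when only a subset of the three gradients is actually used, the loop is rebuilt with just the needed grad/log_g/log_g_sign triplets, shrinking the loop's inputs from 10 + 3 * 3 to 10 + 3 * len(wrt):

import pytensor
import pytensor.tensor as at
from pytensor.compile.mode import get_default_mode

a, b, c, z = at.scalars("a", "b", "c", "z")
out = at.hyp2f1(a, b, c, z)
grad_a = pytensor.grad(out, wrt=a)  # only d/da is requested

mode = get_default_mode().including("local_useless_2f1grad_loop")
f = pytensor.function([a, b, c, z], grad_a, mode=mode)
pytensor.dprint(f)  # the surviving loop carries one grad triplet, not three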

tests/tensor/test_math_scipy.py
Lines changed: 37 additions & 0 deletions
@@ -4,6 +4,8 @@
 import pytest

 from pytensor.gradient import verify_grad
+from pytensor.scalar import ScalarLoop
+from pytensor.tensor.elemwise import Elemwise


 scipy = pytest.importorskip("scipy")
@@ -1052,3 +1054,38 @@ def test_benchmark(self, case, wrt, benchmark):
             expected_result,
             rtol=rtol,
         )
+
+    @pytest.mark.parametrize("wrt", ([0], [1], [2], [0, 1], [1, 2], [0, 2], [0, 1, 2]))
+    def test_unused_grad_loop_opt(self, wrt):
+        """Test that we don't compute unnecessary outputs in the grad scalar loop"""
+        (
+            test_a1,
+            test_a2,
+            test_b1,
+            test_z,
+            *expected_dds,
+            expected_ddz,
+        ) = self.few_iters_case
+
+        a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
+        hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
+        wrt_vars = [v for i, v in enumerate((a1, a2, b1, z)) if i in wrt]
+        hyp2f1_grad = at.grad(hyp2f1_out, wrt=wrt_vars)
+
+        mode = get_default_mode().including("local_useless_2f1grad_loop")
+        f_grad = function([a1, a2, b1, z], hyp2f1_grad, mode=mode)
+
+        [scalar_loop_op] = [
+            node.op.scalar_op
+            for node in f_grad.maker.fgraph.apply_nodes
+            if isinstance(node.op, Elemwise)
+            and isinstance(node.op.scalar_op, ScalarLoop)
+        ]
+        assert scalar_loop_op.nin == 10 + 3 * len(wrt)
+
+        rtol = 1e-9 if config.floatX == "float64" else 2e-3
+        np.testing.assert_allclose(
+            f_grad(test_a1, test_a2, test_b1, test_z),
+            [dd for i, dd in enumerate(expected_dds) if i in wrt],
+            rtol=rtol,
+        )
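
Editor's note on the assertion above: nin == 10 + 3 * len(wrt) follows from the loop construction in pytensor/scalar/math.py. The loop's inputs are n_steps (1), the carried state [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k] (3 * len(wrt) + 4 once the unused triplet slots are dropped), and the constants [a, b, c, log_z, sign_z] (5), for 10 + 3 * len(wrt) in total.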
