5
5
"""
6
6
7
7
import os
8
+ from functools import reduce
8
9
from textwrap import dedent
10
+ from typing import Tuple
9
11
10
12
import numpy as np
11
13
import scipy .special
@@ -683,23 +685,28 @@ def __hash__(self):
683
685
gammaincc = GammaIncC (upgrade_to_float , name = "gammaincc" )
684
686
685
687
686
def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
    """Build a scalar loop `Op` from `inner_loop_fn` and apply it to the given inputs.

    Parameters
    ----------
    n_steps
        Passed as the first argument when the built `Op` is called.
    init
        Initial values of the loop state. Entries may be ``None``; they are still
        handed to `inner_loop_fn` (as ``None``) but are left out of the `Op`.
    constant
        Values that are passed to `inner_loop_fn` after the state entries.
    inner_loop_fn
        Callable invoked as ``inner_loop_fn(*state, *constants)`` that returns
        ``(updates, until)``. ``None`` entries in ``updates`` are dropped.
    name
        Name given to the `Op`.
    loop_op
        Class used to build the `Op`.

    Returns
    -------
    Whatever calling the `Op` returns; callers in this file unpack it as a sequence.
    """

    def without_nones(items):
        return [item for item in items if item is not None]

    outer_state = [None if value is None else as_scalar(value) for value in init]
    outer_consts = [as_scalar(value) for value in constant]

    # Use fresh placeholder variables inside the loop, in case some of the outer
    # variables have the same initial form
    inner_state = [None if var is None else var.type() for var in outer_state]
    inner_consts = [var.type() for var in outer_consts]
    inner_updates, stop_condition = inner_loop_fn(*inner_state, *inner_consts)

    # Only the non-None state entries take part in the actual loop
    op = loop_op(
        init=without_nones(inner_state),
        constant=inner_consts,
        update=without_nones(inner_updates),
        until=stop_condition,
        until_condition_failed="warn",
        name=name,
    )
    return op(n_steps, *without_nones(outer_state), *outer_consts)
703
710
704
711
705
712
def gammainc_grad (k , x ):
@@ -740,7 +747,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):
740
747
741
748
init = [sum_a0 , log_gamma_k_plus_n_plus_1 , k_plus_n ]
742
749
constant = [log_x ]
743
- sum_a = _make_scalar_loop (
750
+ sum_a , * _ = _make_scalar_loop (
744
751
max_iters , init , constant , inner_loop_a , name = "gammainc_grad_a"
745
752
)
746
753
@@ -827,7 +834,7 @@ def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):
827
834
828
835
init = [sum_a0 , delta , xpow , k_minus_one_minus_n , fac , dfac ]
829
836
constant = [x ]
830
- sum_a = _make_scalar_loop (
837
+ sum_a , * _ = _make_scalar_loop (
831
838
n_steps , init , constant , inner_loop_a , name = "gammaincc_grad_a"
832
839
)
833
840
grad_approx_a = (
@@ -870,7 +877,7 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
870
877
871
878
init = [sum_b0 , log_s , s_sign , log_delta , n ]
872
879
constant = [k , log_x ]
873
- sum_b = _make_scalar_loop (
880
+ sum_b , * _ = _make_scalar_loop (
874
881
max_iters , init , constant , inner_loop_b , name = "gammaincc_grad_b"
875
882
)
876
883
grad_approx_b = (
@@ -1540,7 +1547,7 @@ def inner_loop(
1540
1547
1541
1548
init = [derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ]
1542
1549
constant = [f , p , q , K , dK ]
1543
- grad = _make_scalar_loop (
1550
+ grad , * _ = _make_scalar_loop (
1544
1551
max_iters , init , constant , inner_loop , name = "betainc_grad"
1545
1552
)
1546
1553
return grad
@@ -1579,10 +1586,11 @@ def impl(self, a, b, c, z):
1579
1586
def grad(self, inputs, grads):
    """Return the output gradient multiplied by the partials w.r.t. ``a``, ``b``, ``c`` and ``z``.

    The partials for ``a``, ``b`` and ``c`` come from a single `hyp2f1_grad` call.
    The partial for ``z`` is ``(a * b / c) * hyp2f1(a + 1, b + 1, c + 1, z)``.
    """
    a, b, c, z = inputs
    (gz,) = grads

    # One call produces all three parameter partials
    parameter_partials = hyp2f1_grad(a, b, c, z, wrt=[0, 1, 2])
    grad_a, grad_b, grad_c = parameter_partials

    input_grads = [gz * partial for partial in (grad_a, grad_b, grad_c)]
    input_grads.append(gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z))
    return input_grads
1588
1596
@@ -1598,92 +1606,55 @@ def _unsafe_sign(x):
1598
1606
return switch (x > 0 , 1 , - 1 )
1599
1607
1600
1608
1601
- def hyp2f1_grad (a , b , c , z , wrt : int ):
1602
- dtype = upcast (a .type .dtype , b .type .dtype , c .type .dtype , z .type .dtype , "float32" )
1603
-
1604
- def check_2f1_converges (a , b , c , z ):
1605
- def is_nonpositive_integer (x ):
1606
- if x .type .dtype not in integer_types :
1607
- return eq (floor (x ), x ) & (x <= 0 )
1608
- else :
1609
- return x <= 0
1609
class Grad2F1Loop(ScalarLoop):
    """Subclass of ScalarLoop for easier targeting in rewrites.

    The class body adds nothing to ScalarLoop; it only provides a distinct type.
    """
1610
1611
1611
- a_is_polynomial = is_nonpositive_integer (a ) & (scalar_abs (a ) >= 0 )
1612
- num_terms = switch (
1613
- a_is_polynomial ,
1614
- floor (scalar_abs (a )).astype ("int64" ),
1615
- 0 ,
1616
- )
1617
1612
1618
- b_is_polynomial = is_nonpositive_integer (b ) & (scalar_abs (b ) >= num_terms )
1619
- num_terms = switch (
1620
- b_is_polynomial ,
1621
- floor (scalar_abs (b )).astype ("int64" ),
1622
- num_terms ,
1623
- )
1613
def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
    """Build the scalar loop that accumulates the 2F1 series gradients w.r.t. a, b and/or c.

    Parameters
    ----------
    a, b, c, z
        Scalar inputs of the 2F1 function.
    skip_loop
        Symbolic flag; when true the maximum number of steps is switched to 0.
    wrt
        Indices of the parameters to differentiate (0 for ``a``, 1 for ``b``, 2 for ``c``).
        Loop state for indices that are not listed is set to ``None`` and is left out
        of the loop. Assumed to hold distinct indices, since the returned slice is
        ``len(wrt)`` long.
    dtype
        dtype used for the floating point loop state.

    Returns
    -------
    The first ``len(wrt)`` outputs of the loop, i.e. the gradient accumulators of the
    requested parameters in increasing index order.

    Notes
    -----
    The algorithm can be derived by looking at the ratio of two successive terms in the series
    β_{k+1}/β_{k} = A(k)/B(k)
    β_{k+1} = A(k)/B(k) * β_{k}
    d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

    In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z

    The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
    by dropping the respective term
    d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
    d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
    d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)

    The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
    tracking their signs.
    """

    min_steps = np.array(
        10, dtype="int32"
    )  # https://github.com/stan-dev/math/issues/2857
    max_steps = switch(
        skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
    )
    precision = np.array(1e-14, dtype=config.floatX)

    # Per-parameter state; None marks a parameter that was not requested in `wrt`
    grads = [np.array(0, dtype=dtype) if i in wrt else None for i in range(3)]
    log_gs = [np.array(-np.inf, dtype=dtype) if i in wrt else None for i in range(3)]
    log_gs_signs = [np.array(1, dtype="int8") if i in wrt else None for i in range(3)]

    log_t = np.array(0.0, dtype=dtype)
    log_t_sign = np.array(1, dtype="int8")

    log_z = log(scalar_abs(z))
    sign_z = _unsafe_sign(z)

    sign_zk = sign_z
    k = np.array(0, dtype="int32")

    def inner_loop(*args):
        # Argument order mirrors `init` followed by `constant` below
        (
            *grads_vars,
            log_t,
            log_t_sign,
            sign_zk,
            k,
            a,
            b,
            c,
            log_z,
            sign_z,
        ) = args

        (
            grad_a,
            grad_b,
            grad_c,
            log_g_a,
            log_g_b,
            log_g_c,
            log_g_sign_a,
            log_g_sign_b,
            log_g_sign_c,
        ) = grads_vars

        p = (a + k) * (b + k) / ((c + k) * (k + 1))
        if p.type.dtype != dtype:
            p = p.astype(dtype)

        # If p==0, don't update grad and get out of while loop next
        p_zero = eq(p, 0)

        # The terms are built from the previous iteration's log_t, before it is updated
        if 0 in wrt:
            term_a = log_g_sign_a * log_t_sign * exp(log_g_a - log_t)
            term_a += reciprocal(a + k)
            if term_a.type.dtype != dtype:
                term_a = term_a.astype(dtype)
        if 1 in wrt:
            term_b = log_g_sign_b * log_t_sign * exp(log_g_b - log_t)
            term_b += reciprocal(b + k)
            if term_b.type.dtype != dtype:
                term_b = term_b.astype(dtype)
        if 2 in wrt:
            term_c = log_g_sign_c * log_t_sign * exp(log_g_c - log_t)
            term_c -= reciprocal(c + k)
            if term_c.type.dtype != dtype:
                term_c = term_c.astype(dtype)

        log_t = log_t + log(scalar_abs(p)) + log_z
        log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")

        grads = [None] * 3
        log_gs = [None] * 3
        log_gs_signs = [None] * 3
        grad_incs = [None] * 3

        if 0 in wrt:
            log_g_a = log_t + log(scalar_abs(term_a))
            log_g_sign_a = (_unsafe_sign(term_a) * log_t_sign).astype("int8")
            grad_inc_a = log_g_sign_a * exp(log_g_a) * sign_zk
            grads[0] = switch(p_zero, grad_a, grad_a + grad_inc_a)
            log_gs[0] = log_g_a
            log_gs_signs[0] = log_g_sign_a
            grad_incs[0] = grad_inc_a
        if 1 in wrt:
            log_g_b = log_t + log(scalar_abs(term_b))
            log_g_sign_b = (_unsafe_sign(term_b) * log_t_sign).astype("int8")
            grad_inc_b = log_g_sign_b * exp(log_g_b) * sign_zk
            grads[1] = switch(p_zero, grad_b, grad_b + grad_inc_b)
            log_gs[1] = log_g_b
            log_gs_signs[1] = log_g_sign_b
            grad_incs[1] = grad_inc_b
        if 2 in wrt:
            log_g_c = log_t + log(scalar_abs(term_c))
            log_g_sign_c = (_unsafe_sign(term_c) * log_t_sign).astype("int8")
            grad_inc_c = log_g_sign_c * exp(log_g_c) * sign_zk
            grads[2] = switch(p_zero, grad_c, grad_c + grad_inc_c)
            log_gs[2] = log_g_c
            log_gs_signs[2] = log_g_sign_c
            grad_incs[2] = grad_inc_c

        sign_zk *= sign_z
        k += 1

        # Convergence is judged on the largest absolute increment among the requested
        # parameters. `reduce` returns the sole element unchanged when only one
        # parameter is requested, so no special case is needed.
        abs_grad_incs = [
            scalar_abs(grad_inc) for grad_inc in grad_incs if grad_inc is not None
        ]
        max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)

        return (
            (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
            (p_zero | ((k > min_steps) & (max_abs_grad_inc <= precision))),
        )

    init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
    constant = [a, b, c, log_z, sign_z]
    loop_outs = _make_scalar_loop(
        max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop
    )
    # The gradient accumulators come first in `init`, so they are the leading outputs
    return loop_outs[: len(wrt)]
1716
1759
1717
- g_current = log_g_sign * exp (log_g ) * sign_zk
1718
1760
1719
- # If p==0, don't update grad and get out of while loop next
1720
- grad = switch (
1721
- eq (p , 0 ),
1722
- grad ,
1723
- grad + g_current ,
1724
- )
1761
def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
    """Return the gradients of 2F1(a, b; c; z) w.r.t. the parameters selected by `wrt`.

    Parameters
    ----------
    a, b, c, z
        Scalar inputs of the 2F1 function.
    wrt
        Indices of the parameters to differentiate, forwarded to `_grad_2f1_loop`.

    Returns
    -------
    A list with one entry per gradient returned by `_grad_2f1_loop`. Each entry is
    0 when ``z == 0``, NaN when the convergence check fails, and the loop result otherwise.
    """
    out_dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")

    def is_nonpositive_integer(x):
        if x.type.dtype in integer_types:
            return x <= 0
        return eq(floor(x), x) & (x <= 0)

    z_is_zero = eq(z, 0)

    # A non-positive integer `a` or `b` is flagged as the polynomial case; the
    # magnitude of the larger such parameter is tracked as the number of terms.
    a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
    n_terms = switch(a_is_polynomial, floor(scalar_abs(a)).astype("int64"), 0)

    b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= n_terms)
    n_terms = switch(b_is_polynomial, floor(scalar_abs(b)).astype("int64"), n_terms)

    # A non-positive integer `c` whose magnitude does not exceed the number of terms
    # is flagged as undefined
    c_is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= n_terms)
    any_polynomial = a_is_polynomial | b_is_polynomial

    converges = (~c_is_undefined) & (
        any_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
    )

    # The switch below is not lazy, so the loop itself has to be told when to do
    # no work: that is what the skip_loop flag is for.
    loop_grads = _grad_2f1_loop(
        a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=out_dtype
    )

    def guard(loop_grad):
        return switch(z_is_zero, 0, switch(converges, loop_grad, np.nan))

    return [guard(loop_grad) for loop_grad in loop_grads]
0 commit comments