From 40c023cb150687a6ee67092e8461dcafdaa5b063 Mon Sep 17 00:00:00 2001
From: Nym Seddon <unseddd@shh.xyz>
Date: Thu, 21 Jan 2021 22:04:18 +0000
Subject: [PATCH] Add optimized binary extended gcd algorithm

Optimized to use cheap shifts and adds rather than multiplications and
divisions for finding the greated common divisor and Bezout coefficients.
---
 src/lib.rs | 178 +++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 146 insertions(+), 32 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 0281954..1309a77 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,7 +22,7 @@ extern crate std;
 extern crate num_traits as traits;
 
 use core::mem;
-use core::ops::Add;
+use core::ops::{Add, Shl, Shr, Sub};
 
 use traits::{Num, Signed, Zero};
 
@@ -34,6 +34,140 @@ mod average;
 pub use average::Average;
 pub use average::{average_ceil, average_floor};
 
+/// Greatest common divisor and Bézout coefficients
+///
+/// # Examples
+///
+/// ~~~
+/// # use num_integer::{extended_binary_gcd, ExtendedGcd, Integer};
+/// # fn main() {
+/// let a = 693i16;
+/// let b = 609i16;
+/// let ExtendedGcd { gcd, x, y, .. } = extended_binary_gcd(&a, &b);
+/// assert_eq!(gcd, 21i16);
+/// assert_eq!(x, -181i16);
+/// assert_eq!(y, 206i16);
+/// # }
+/// ~~~
+///
+/// Based on "Binary extended gcd algorithm",
+/// Handbook of Applied Cryptography, Ch. 14, Ss. 14.61
+pub fn extended_binary_gcd<T>(this: &T, other: &T) -> ExtendedGcd<T>
+where
+    T: Clone
+        + Integer
+        + Shl<u32, Output = T>
+        + Shr<u32, Output = T>
+        + for<'a> Add<&'a T, Output = T>
+        + for<'a> Sub<&'a T, Output = T>,
+{
+    let zero = T::zero();
+
+    if this <= &zero || other <= &zero {
+        panic!("base and other must be positive, non-zero integers");
+    }
+
+    let mut echs = this.clone();
+    let mut why = other.clone();
+
+    let mut gg = T::one();
+
+    while echs.is_even() && why.is_even() {
+        echs = echs >> 1u32;
+        why = why >> 1u32;
+        gg = gg << 1u32;
+    }
+
+    let mut xx = echs.clone();
+    let mut yy = if why < zero { zero - &why } else { why.clone() };
+
+    let mut ba = T::one();
+    let mut bb = T::zero();
+    let mut bc = T::zero();
+    let mut bd = T::one();
+
+    while !xx.is_zero() {
+        while xx.is_even() {
+            xx = xx >> 1u32;
+
+            if ba.is_odd() || bb.is_odd() {
+                ba = ba + &why;
+                bb = bb - &echs;
+            }
+
+            ba = ba >> 1u32;
+            bb = bb >> 1u32;
+        }
+
+        while yy.is_even() {
+            yy = yy >> 1u32;
+
+            if bc.is_odd() || bd.is_odd() {
+                bc = bc + &why;
+                bd = bd - &echs;
+            }
+
+            bc = bc >> 1u32;
+            bd = bd >> 1u32;
+        }
+
+        if xx >= yy {
+            xx = xx - &yy;
+            ba = ba - &bc;
+            bb = bb - &bd;
+        } else {
+            yy = yy - &xx;
+            bc = bc - &ba;
+            bd = bd - &bb;
+        }
+    }
+
+    ExtendedGcd {
+        gcd: gg * yy,
+        x: bc,
+        y: bd,
+        _hidden: (),
+    }
+}
+
+fn extended_gcd<T>(this: &T, other: &T) -> ExtendedGcd<T>
+where
+    T: Clone + Integer,
+{
+    let mut s = (T::zero(), T::one());
+    let mut t = (T::one(), T::zero());
+    let mut r = (other.clone(), this.clone());
+
+    while !r.0.is_zero() {
+        let q = r.1.clone() / r.0.clone();
+        let f = |mut r: (T, T)| {
+            mem::swap(&mut r.0, &mut r.1);
+            r.0 = r.0 - q.clone() * r.1.clone();
+            r
+        };
+
+        r = f(r);
+        s = f(s);
+        t = f(t);
+    }
+
+    if r.1 >= T::zero() {
+        ExtendedGcd {
+            gcd: r.1,
+            x: s.1,
+            y: t.1,
+            _hidden: (),
+        }
+    } else {
+        ExtendedGcd {
+            gcd: T::zero() - r.1,
+            x: T::zero() - s.1,
+            y: T::zero() - t.1,
+            _hidden: (),
+        }
+    }
+}
+
 pub trait Integer: Sized + Num + PartialOrd + Ord + Eq {
     /// Floored integer division.
     ///
@@ -166,37 +300,7 @@ pub trait Integer: Sized + Num + PartialOrd + Ord + Eq {
     where
         Self: Clone,
     {
-        let mut s = (Self::zero(), Self::one());
-        let mut t = (Self::one(), Self::zero());
-        let mut r = (other.clone(), self.clone());
-
-        while !r.0.is_zero() {
-            let q = r.1.clone() / r.0.clone();
-            let f = |mut r: (Self, Self)| {
-                mem::swap(&mut r.0, &mut r.1);
-                r.0 = r.0 - q.clone() * r.1.clone();
-                r
-            };
-            r = f(r);
-            s = f(s);
-            t = f(t);
-        }
-
-        if r.1 >= Self::zero() {
-            ExtendedGcd {
-                gcd: r.1,
-                x: s.1,
-                y: t.1,
-                _hidden: (),
-            }
-        } else {
-            ExtendedGcd {
-                gcd: Self::zero() - r.1,
-                x: Self::zero() - s.1,
-                y: Self::zero() - t.1,
-                _hidden: (),
-            }
-        }
+        extended_gcd(&self, other)
     }
 
     /// Greatest common divisor, least common multiple, and Bézout coefficients.
@@ -502,6 +606,16 @@ macro_rules! impl_integer_for_isize {
                 m << shift
             }
 
+            #[inline]
+            fn extended_gcd(&self, other: &Self) -> ExtendedGcd<Self> {
+                let zero = Self::zero();
+                if self > &zero && other > &zero {
+                    extended_binary_gcd(&self, other)
+                } else {
+                    extended_gcd(&self, other)
+                }
+            }
+
             #[inline]
             fn extended_gcd_lcm(&self, other: &Self) -> (ExtendedGcd<Self>, Self) {
                 let egcd = self.extended_gcd(other);