From 2bce6142051699da9dffd2b7ab5a3c948a1c37d9 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jan 2025 15:23:16 +0000 Subject: [PATCH 1/2] BUG: `isclose` integer overflow --- src/array_api_extra/_lib/_funcs.py | 7 ++++++- tests/test_funcs.py | 7 +++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 0af265e2..7a43fbf0 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -335,7 +335,12 @@ def isclose( atol = int(atol) if rtol == 0: return xp.abs(a - b) <= atol - nrtol = int(1.0 / rtol) + + try: + nrtol = xp.asarray(int(1.0 / rtol), dtype=b.dtype) + except OverflowError: + return xp.abs(a - b) <= atol + return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ef1a1fc2..eee65145 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -354,6 +354,13 @@ def test_tolerance(self, dtype: str, xp: ModuleType): xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False])) xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False])) + @pytest.mark.parametrize("dtype", ["int8", "uint8"]) + def test_tolerance_integer_overflow(self, dtype: str, xp: ModuleType): + """1/rtol is too large for dtype""" + a = xp.asarray([100, 100], dtype=getattr(xp, dtype)) + b = xp.asarray([100, 101], dtype=getattr(xp, dtype)) + xp_assert_equal(isclose(a, b), xp.asarray([True, False])) + def test_very_small_numbers(self, xp: ModuleType): a = xp.asarray([1e-9, 1e-9]) b = xp.asarray([1.0001e-9, 1.00001e-9]) From 2b8aa429704a26e8e569c31a2c9b0b3f42d627a2 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jan 2025 15:43:12 +0000 Subject: [PATCH 2/2] Clarification --- src/array_api_extra/_lib/_funcs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 7a43fbf0..384a6db1 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -339,6 +339,7 @@ def isclose( try: nrtol = xp.asarray(int(1.0 / rtol), dtype=b.dtype) except OverflowError: + # rtol * max_int(dtype) < 1, so it's inconsequential return xp.abs(a - b) <= atol return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)