diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 114f42d..515f52e 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -11,8 +11,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.10', '3.11', '3.12', '3.13'] - numpy-version: ['1.26', 'dev'] + python-version: ['3.12', '3.13'] + numpy-version: ['1.26', '2.2', 'dev'] exclude: - python-version: '3.13' numpy-version: '1.26' @@ -38,7 +38,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy>=1.26,<2.0'; + python -m pip install 'numpy=='${{ matrix.numpy-version }}; fi python -m pip install ${GITHUB_WORKSPACE}/array-api-strict python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 7242055..9906f97 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -67,9 +67,7 @@ def __hash__(self) -> int: CPU_DEVICE = Device() ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2")) -# See https://github.com/data-apis/array-api-strict/issues/67 and the comment -# on __array__ below. -_allow_array = True +_default = object() class Array: @@ -151,40 +149,28 @@ def __repr__(self) -> str: __str__ = __repr__ - # In the future, _allow_array will be set to False, which will disallow - # __array__. This means calling `np.func()` on an array_api_strict array - # will give an error. If we don't explicitly disallow it, NumPy defaults - # to creating an object dtype array, which would lead to confusing error - # messages at best and surprising bugs at worst. The reason for doing this - # is that __array__ is not actually supported by the standard, so it can - # lead to code assuming np.asarray(other_array) would always work in the - # standard. - # - # This was implemented historically for compatibility, and removing it has + # `__array__` was implemented historically for compatibility, and removing it has # caused issues for some libraries (see # https://github.com/data-apis/array-api-strict/issues/67). - def __array__( - self, dtype: None | np.dtype[Any] = None, copy: None | bool = None - ) -> npt.NDArray[Any]: - # We have to allow this to be internally enabled as there's no other - # easy way to parse a list of Array objects in asarray(). - if _allow_array: - if self._device != CPU_DEVICE: - raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") - # copy keyword is new in 2.0.0; for older versions don't use it - # retry without that keyword. - if np.__version__[0] < '2': - return np.asarray(self._array, dtype=dtype) - elif np.__version__.startswith('2.0.0-dev0'): - # Handle dev version for which we can't know based on version - # number whether or not the copy keyword is supported. - try: - return np.asarray(self._array, dtype=dtype, copy=copy) - except TypeError: - return np.asarray(self._array, dtype=dtype) - else: - return np.asarray(self._array, dtype=dtype, copy=copy) - raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") + + # Instead of `__array__` we now implement the buffer protocol. + # Note that it makes array-apis-strict requiring python>=3.12 + def __buffer__(self, flags): + if self._device != CPU_DEVICE: + raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") + return self._array.__buffer__(flags) + + # We do not define __release_buffer__, per the discussion at + # https://github.com/data-apis/array-api-strict/pull/115#pullrequestreview-2917178729 + + def __array__(self, *args, **kwds): + # a stub for python < 3.12; otherwise numpy silently produces object arrays + import sys + minor, major = sys.version_info.minor, sys.version_info.major + if major < 3 or minor < 12: + raise TypeError( + "Interoperation with NumPy requires python >= 3.12. Please upgrade." + ) # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 69d37aa..64c51ce 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -1,7 +1,5 @@ from __future__ import annotations -from collections.abc import Generator -from contextlib import contextmanager from enum import Enum from typing import TYPE_CHECKING, Literal @@ -26,21 +24,6 @@ class Undef(Enum): _undef = Undef.UNDEF -@contextmanager -def allow_array() -> Generator[None]: - """ - Temporarily enable Array.__array__. This is needed for np.array to parse - list of lists of Array objects. - """ - from . import _array_object - original_value = _array_object._allow_array - try: - _array_object._allow_array = True - yield - finally: - _array_object._allow_array = original_value - - def _check_valid_dtype(dtype: DType | None) -> None: # Note: Only spelling dtypes as the dtype objects is supported. if dtype not in (None,) + _all_dtypes: @@ -123,8 +106,8 @@ def asarray( # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") - with allow_array(): - res = np.array(obj, dtype=_np_dtype, copy=copy) + + res = np.array(obj, dtype=_np_dtype, copy=copy) return Array._new(res, device=device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index ae6627a..bb4263c 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,3 +1,4 @@ +import sys import operator from builtins import all as all_ @@ -526,6 +527,10 @@ def test_array_properties(): assert b.mT.shape == (3, 2) +@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312, + reason="array conversion relies on buffer protocol, and " + "requires python >= 3.12" +) def test_array_conversion(): # Check that arrays on the CPU device can be converted to NumPy # but arrays on other devices can't. Note this is testing the logic in @@ -536,25 +541,23 @@ def test_array_conversion(): for device in ("device1", "device2"): a = ones((2, 3), device=array_api_strict.Device(device)) - with pytest.raises(RuntimeError, match="Can not convert array"): + with pytest.raises((RuntimeError, ValueError)): np.asarray(a) -def test__array__(): - # __array__ should work for now + # __buffer__ should work for now for conversion to numpy a = ones((2, 3)) - np.array(a) - - # Test the _allow_array private global flag for disabling it in the - # future. - from .. import _array_object - original_value = _array_object._allow_array - try: - _array_object._allow_array = False - a = ones((2, 3)) - with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"): - np.array(a) - finally: - _array_object._allow_array = original_value + na = np.array(a) + assert na.shape == (2, 3) + assert na.dtype == np.float64 + +@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312, + reason="conversion to numpy errors out unless python >= 3.12" +) +def test_array_conversion_2(): + a = ones((2, 3)) + with pytest.raises(TypeError): + np.array(a) + def test_allow_newaxis(): a = ones(5)