
RFC: require that dtypes obey Python hashing rules #582

Open
@NeilGirdhar

Description

Python's documentation states: "The only required property is that objects which compare equal have the same hash value…" However, NumPy dtypes do not meet this requirement. As discussed in numpy/numpy#7242, a dtype object compares equal both to its scalar type and to its name, even though their hashes differ. Could the Array API standard guarantee that this is not the case for its dtypes?
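To make the consequence concrete without depending on NumPy, here is a minimal sketch using a hypothetical LooseDType class that, like np.dtype, compares equal to its name string but hashes differently. The class and its deliberately different hash are assumptions for illustration, not NumPy's implementation. The practical effect of breaking the invariant is that set and dict lookups silently miss.

```python
class LooseDType:
    """Hypothetical dtype that, like np.dtype, compares equal to its name string."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        if isinstance(other, LooseDType):
            return self.name == other.name
        if isinstance(other, str):
            return self.name == other  # the loose part: equal to a plain string
        return NotImplemented

    def __hash__(self):
        # Deliberately never equal to hash(self.name): bitwise NOT never returns its input.
        return ~hash(self.name)


f32 = LooseDType("float32")
print(f32 == "float32")         # True
print("float32" in {f32})       # False: the hashes differ, so the lookup never matches
print({f32: 4}.get("float32"))  # None: a dict lookup by name silently misses
```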

Activity

rgommers commented on Dec 23, 2022

Member

That seems fine to me to specify explicitly. float32 == 'float32' should clearly return False. In NumPy it's a bit messy:

>>> import numpy as np
>>> np.float32 == 'float32'
False
>>> np.dtype(np.float32) == 'float32'
True

Only the first example is relevant for the array API standard, so I think this will be fine to specify since NumPy already complies.

For this one, however, there is a problem in NumPy:

>>> np.dtype(np.float64) == float
True

That can be considered a clear bug though, should be fixed in NumPy.

NeilGirdhar commented on Dec 23, 2022

Author

Only the first example is relevant for the array API standard, so I think this will be fine to specify since NumPy already complies.

So you're saying that np.dtype(np.float32) == 'float32' will be true or false?

That can be considered a clear bug though, should be fixed in NumPy.

Agreed.

What about np.float32 == np.dtype(np.float32)?

This also violates Python's hashing invariant.

rgommers commented on Dec 23, 2022

Member

What about np.float32 == np.dtype(np.float32)?

I agree that it's a bug technically. Not 100% sure that the NumPy team will want that changed, but I hope so (and a proposal for a major release is in the works, so that could go into it). For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.

So you're saying that np.dtype(np.float32) == 'float32' will be blocked or not?

That's more for the NumPy issue tracker, but if it were up to me then yes.

For this issue tracker, I'm +1 on adopting language in the standard like: "All objects in this standard must adhere to the following requirement (as required by Python itself): objects which compare equal have the same hash value".
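A requirement phrased this way is easy to test mechanically. Below is a minimal sketch of a conformance check; the function name is hypothetical and not part of any existing test suite. It reports every pair of objects that compare equal but hash differently.

```python
from itertools import combinations


def hash_invariant_violations(objects):
    """Return the pairs (a, b) with a == b but hash(a) != hash(b)."""
    violations = []
    for a, b in combinations(objects, 2):
        # Check both directions: a.__eq__ may answer False while b.__eq__ answers True.
        if (a == b or b == a) and hash(a) != hash(b):
            violations.append((a, b))
    return violations


# Built-in numbers satisfy the rule: 1 == 1.0 == True, and all three hash to 1.
print(hash_invariant_violations([1, 1.0, True]))  # []
```

For a real library, one could pass a namespace's dtype objects together with their names and Python scalar types; under the proposed rule, a compliant namespace should yield an empty list.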

NeilGirdhar commented on Dec 23, 2022

Author

For this issue tracker, I'm +1 on adopting language in the standard like: "All objects in this standard must adhere to the following requirement (as required by Python itself): objects which compare equal have the same hash value".

That would be amazing. That's exactly what I was hoping for.

I agree that it's a bug technically. Not 100% sure that the NumPy team will want that changed, but I hope so (and a proposal for a major release is in the works, so that could go into it). For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.

Okay, thanks for explaining. If the above language were adopted, NumPy could comply by making xp.float32 not simply an alias of np.dtype(np.float32), but rather a dedicated dtype object that doesn't have this pernicious behavior.
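Here is one way such a dedicated dtype object could look, as a sketch under assumptions: StrictDType and its fields are hypothetical, not NumPy's design. A frozen dataclass only compares equal to instances of the same class, and its generated __hash__ is derived from the same fields as __eq__, so the invariant holds by construction.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StrictDType:
    """Hypothetical array API dtype: equal only to another StrictDType."""

    name: str
    itemsize: int  # size of one element in bytes


float32 = StrictDType("float32", 4)
float64 = StrictDType("float64", 8)

# The generated __eq__ returns NotImplemented for other classes, so comparisons
# with strings or Python types fall back to identity and are False.
print(float32 == StrictDType("float32", 4))  # True
print(float32 == "float32")                  # False
print(float32 == float)                      # False
# Equal objects hash equally, because __hash__ uses the same fields as __eq__.
print(hash(float32) == hash(StrictDType("float32", 4)))  # True
```

Because the generated __eq__ requires the exact same class, a second library's own dtype class would also compare unequal, which matches the position taken later in this thread that objects from two different libraries should not compare equal.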

rgommers commented on Dec 23, 2022

Member

Let's give it a bit of time to see if anyone sees a reason not to add such a requirement. I can open a PR after the holidays.

NeilGirdhar commented on Dec 25, 2022

Author

For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.

Just noticed this comment. It is currently an issue in NumPy's implementation of the Array API:

import numpy.array_api as xp
xp.float32 == xp.float32.type  # True!

This is because xp.float32 is the very object np.dtype(np.float32). To fix this, NumPy would just need a new dtype class for use in its array API namespace (xp).

With the language you suggested above, NumPy would be forced to do this to become compliant 😄 .

So you're saying that np.dtype(np.float32) == 'float32' will be blocked or not?

That's more for the NumPy issue tracker, but if it were up to me then yes.

Same thing here, I think. NumPy will probably reject this for its own namespace (np), but if you adopt that language, it would have to be fixed in the array API namespace (xp).

Incidentally, I assume you want numpy.array_api.float32 to compare equal to jax.array_api.float32? Since there is no root project to provide a base implementation of dtypes, you may need to standardize how dtype.__hash__ and comparison work.

rgommers commented on Dec 26, 2022

Member

xp.float32 == xp.float32.type # True!

There is no float32.type in the standard. That it shows up with numpy.array_api.float32 is because the dtype objects there are aliases to the regular numpy ones, rather than new objects. That was a shortcut I think, because adding new dtypes is a lot of work. So that's one place where currently numpy.array_api doesn't 100% meet its goal of being completely minimal.

Incidentally, I assume you want numpy.array_api.float32 to compare equal to jax.array_api.float32?

No, definitely not. No objects from two different libraries should ever compare equal, unless they're indeed the same object.

NeilGirdhar commented on Dec 26, 2022

Author

So that's one place where currently numpy.array_api doesn't 100% meet its goal of being completely minimal.

Ok! Thanks for explaining.

No, definitely not. No objects from two different libraries should ever compare equal, unless they're indeed the same object.

So to do things like checking that two arrays have the same dtype, or creating a NumPy array that has the same dtype as a Jax array, we'll need mappings like:

m = {jax.array_api.float32: np.array_api.float32, ...}

And code like

np.array_api.ones_like(some_jax_array)  # works today, in either direction.

is impossible, yes? You need:

np.array_api.ones(some_jax_array.shape, dtype=m[some_jax_array.dtype])

rgommers commented on Dec 26, 2022

Member

So to do things like checking that two arrays have the same dtype ...

Having to use library-specific constructs should not be necessary; if it is, we're missing an API, I'd say. More importantly: mixing arrays from different libraries like this is a bit of an anti-pattern. You can't do much with that: neither library has kernels for functions that use both array types, so you're probably relying on implicit conversion of one to the other.

So in this case, let me assume that x is a NumPy array, y a JAX array, and you want to use functions from x's namespace (NumPy):

# First retrieve the namespace you want to work with
xp = x.__array_namespace__()
# Use DLPack or the buffer protocol to convert a CPU JAX array to a NumPy array
y = xp.asarray(y)
# Now we can compare dtypes:
if x.dtype == y.dtype == xp.float32:
    ...  # If the same dtypes, do stuff

# Or, similarly (isdtype takes a dtype, not an array):
if xp.isdtype(x.dtype, xp.float32) and xp.isdtype(y.dtype, xp.float32):
    ...  # do stuff

is impossible, yes? You need:

yes indeed

I'm actually a little surprised JAX accepts numpy arrays. It seems to go against its philosophy; TensorFlow, PyTorch and CuPy will all raise. When you call jnp.xxx(a_numpy_array), JAX will also make a copy always I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.

JAX also annotates its array inputs as array_like, but that doesn't mean the same thing as it does for NumPy:

>>> jnp.sin([1, 2, 3])
...
TypeError: sin requires ndarray or scalar arguments, got <class 'list'> at position 0

All this stuff is bug-prone:

>>> jnp.sin(np.array([1, 2, 3]))
Array([0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

>>> jnp.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))  # bug in user code here, because JAX silently discards mask
Array([0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

>>> np.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))
masked_array(data=[--, 0.9092974268256816, --],
             mask=[ True, False,  True],
       fill_value=1e+20)

NeilGirdhar commented on Dec 26, 2022

Author

More importantly: mixing arrays from different libraries like this is a bit of an anti-pattern. You can't do much with that, neither library has kernels for functions that use both array types, so you're probably relying on implicit conversion of one to the other.

Okay, makes sense. I haven't been very conscious about this because (as you pointed out) Jax implicitly converts. I will be more careful.

y = xp.asarray(y)

I think this is where I'm confused. Somehow numpy has to know what its equivalent dtypes are for Jax's dtypes even though they don't compare equal? Or will it produce a numpy array with a Jax dtype? I ask because this seems to work:

In [12]: x = jnp.ones(10, jnp.bfloat16)

In [14]: np.asarray(x)
Out[14]: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=bfloat16)

When you call jnp.xxx(a_numpy_array), JAX will also make a copy always I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.

Very interesting. I wonder what the Jax team would say.

rgommers commented on Dec 26, 2022

Member

I think this is where I'm confused. Somehow numpy has to know what its equivalent dtypes are for Jax's dtypes even though they don't compare equal? Or will it produce a numpy array with a Jax dtype?

NumPy knows the dtype, as does JAX. This conversion uses the Python buffer protocol or DLPack, both of which are protocols explicitly meant for exchanging data in a reliable way (that includes dtype, shape, endianness, etc.). So the asarray call will produce a numpy array with a numpy dtype, and to do so numpy does not need to know anything specifically about JAX.
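The buffer-protocol side of this can be seen with the standard library alone. In this minimal sketch, memoryview exposes the element format (a struct-module code), item size, and shape of any buffer-exporting object, which is enough for a consumer to pick its own dtype without knowing the producer. FORMAT_TO_DTYPE and describe_buffer are hypothetical names; the table is partial and assumes native byte order and the usual C type sizes.

```python
import array

# Hypothetical, partial map from struct-module format codes to dtype names.
# Assumes native byte order and the usual sizes (C float = 4 bytes, double = 8).
FORMAT_TO_DTYPE = {
    "?": "bool",
    "b": "int8",
    "B": "uint8",
    "f": "float32",
    "d": "float64",
}


def describe_buffer(obj):
    """Read dtype-relevant metadata from any object supporting the buffer protocol."""
    view = memoryview(obj)
    return {
        "dtype": FORMAT_TO_DTYPE.get(view.format, view.format),
        "itemsize": view.itemsize,
        "shape": view.shape,
    }


# The producer (here array.array) knows nothing about the consumer.
info = describe_buffer(array.array("f", [1.0, 2.0, 3.0]))
print(info)
```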

When you call jnp.xxx(a_numpy_array), JAX will also make a copy always I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.

Very interesting. I wonder what the Jax team would say.

Let's try to find out:) This section of the JAX docs only explains why JAX doesn't accept list/tuple/etc., but I cannot find an explanation of why it does accept numpy arrays and scalars. @shoyer or @jakevdp, would you be able to comment on why JAX implements a limited form of "array-like"?

Also, in addition to the bug with masked arrays above, here is another one:

>>> jnp.sin(np.float64(1.5))  # silent precision loss here, downcasting to float32
Array(0.997495, dtype=float32)
>>> jax.__version__
'0.4.1'

NeilGirdhar commented on Dec 26, 2022

Author

So the asarray call will produce a numpy array with a numpy dtype, and to do so numpy does not need to know anything specifically about JAX.

In that case, should there be a way to convert dtypes using either the buffer protocol or DLPack? Something more efficient than:

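from __future__ import annotations  # DType and ArrayInterface are placeholder names

def x_to_y_dtype(some_xp_dtype: DType, yp: ArrayInterface) -> DType:
    xp = some_xp_dtype.__array_namespace__()  # doesn't exist on dtypes
    x = xp.ones((), dtype=some_xp_dtype)  # a 0-d array carrying the dtype
    y = yp.asarray(x)  # exchange via DLPack or the buffer protocol
    return y.dtype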

Should dtypes have an __array_namespace__ attribute? Currently, they don't. So, the above function can't be written unless you know xp.

rgommers commented on Dec 26, 2022

Member

No, those protocols are specifically for exchanging data (strided arrays/buffers). A dtype without data isn't very meaningful. You could exchange a size-1 array if needed, or a 'float32' string representation, or whatever works.
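The 'float32' string route can be sketched using only the dtype attribute names the standard defines. Assumptions: both namespaces expose the standard dtype names as attributes, and dtype equality is only meaningful within one library. translate_dtype is a hypothetical helper, and the SimpleNamespace objects merely stand in for real array API namespaces.

```python
from types import SimpleNamespace

# Dtype names defined by the array API standard.
STANDARD_DTYPE_NAMES = (
    "bool",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float32", "float64",
    "complex64", "complex128",
)


def translate_dtype(dtype, xp, yp):
    """Return yp's dtype with the same standard name as `dtype` from xp."""
    for name in STANDARD_DTYPE_NAMES:
        # Only compare within one library, where equality is meaningful.
        if hasattr(xp, name) and getattr(xp, name) == dtype:
            return getattr(yp, name)
    raise ValueError(f"{dtype!r} is not a standard dtype of {xp!r}")


# Stand-in namespaces whose dtypes are plain, mutually unequal sentinels.
xp = SimpleNamespace(float32=object(), int64=object())
yp = SimpleNamespace(float32=object(), int64=object())
assert translate_dtype(xp.float32, xp, yp) is yp.float32
```

The design choice here is that no object from one library is ever compared with an object from the other; the only shared vocabulary is the standard's dtype names.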

NeilGirdhar commented on Dec 26, 2022

Author

No, those protocols are specifically for exchanging data (strided arrays/buffers).

I understand, but in order to exchange data, they have to be able to convert dtypes. So, that dtype conversion is happening somehow, and I was just wondering if that conversion can be accessed by the user.
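On the DLPack side, the conversion goes through a small, library-neutral descriptor: dlpack.h describes each element type as a DLDataType triple (type code, bits, lanes). Below is a Python sketch of how a producer and a consumer might each map their own dtype names to and from that triple. The code values follow the DLDataTypeCode enum in dlpack.h (subset only). The name tables and the to_dlpack/from_dlpack helpers are hypothetical and are not an API exposed by NumPy or JAX, which do this mapping internally.

```python
import re
from enum import IntEnum
from typing import NamedTuple


class DLDataTypeCode(IntEnum):
    """Subset of the DLDataTypeCode enum from dlpack.h."""
    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLBfloat = 4
    kDLComplex = 5


class DLDataType(NamedTuple):
    """Library-neutral element type: (type code, bits per lane, lanes)."""
    code: DLDataTypeCode
    bits: int
    lanes: int = 1  # > 1 only for vector (SIMD) element types


# Hypothetical name tables; a real library maps its own dtype objects in C.
_CODE_BY_KIND = {
    "int": DLDataTypeCode.kDLInt,
    "uint": DLDataTypeCode.kDLUInt,
    "float": DLDataTypeCode.kDLFloat,
    "bfloat": DLDataTypeCode.kDLBfloat,
    "complex": DLDataTypeCode.kDLComplex,
}
_KIND_BY_CODE = {code: kind for kind, code in _CODE_BY_KIND.items()}


def to_dlpack(name: str) -> DLDataType:
    """Producer side, e.g. 'float32' -> DLDataType(kDLFloat, 32, 1)."""
    match = re.fullmatch(r"(u?int|b?float|complex)(\d+)", name)
    if match is None:
        raise ValueError(f"unsupported dtype name: {name!r}")
    kind, bits = match.groups()
    return DLDataType(_CODE_BY_KIND[kind], int(bits))


def from_dlpack(dtype: DLDataType) -> str:
    """Consumer side: map the neutral descriptor back to a local dtype name."""
    return f"{_KIND_BY_CODE[dtype.code]}{dtype.bits}"


print(to_dlpack("bfloat16"))
print(from_dlpack(to_dlpack("complex64")))  # complex64
```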

29 remaining items

Metadata

Assignees

No one assigned

    Labels

    RFC: Request for comments. Feature requests and proposed changes.

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

      Development

      No branches or pull requests

        Participants

        @seberg @asmeurer @rgommers @NeilGirdhar @jakevdp


          RFC: require that dtypes obey Python hashing rules · Issue #582 · data-apis/array-api