Skip to content

Improve error message when passing a tuple of non-concrete values as … #7952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

hawkinsp
Copy link
Collaborator

…an axis argument.

Previously, we would see the less helpful error:

jax._src.errors.TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=5/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

Whereas now we see:

jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The axis argument must be known statically.
While tracing the function h at /Users/phawkins/p/jax/tests/lax_numpy_test.py:3663 for jit, this concrete value was not available in Python because it depends on the value of the argument 'y'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

…an axis argument.

Previously, we would see the less helpful error:

```
jax._src.errors.TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=5/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
```

Whereas now we see:

```
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The axis argument must be known statically.
While tracing the function h at /Users/phawkins/p/jax/tests/lax_numpy_test.py:3663 for jit, this concrete value was not available in Python because it depends on the value of the argument 'y'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
@hawkinsp hawkinsp requested a review from jakevdp September 17, 2021 17:51
@google-cla google-cla bot added the cla: yes label Sep 17, 2021
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Sep 17, 2021
@jakevdp jakevdp added the better_errors Improve the error reporting label Sep 17, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
better_errors Improve the error reporting cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants