-
Notifications
You must be signed in to change notification settings - Fork 3.1k
add pargmin_p / pargmax_p primitives #5694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
4108482
to
a3a439e
Compare
9a72481
to
68b8f73
Compare
68b8f73
to
9fffa86
Compare
else: | ||
# TODO(mattjj): avals with names, check and eliminate axis name | ||
shape = np.shape(x) | ||
return ShapedArray(shape, np.dtype('int32')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought that argmin
and argmax
always return int64
? Or they do at least when used from the NumPy API. The primitives themselves are parameterized by the integer type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed we'd want that for the jnp
API, though we don't necessarily need it on the lax API here.
lax.argmin
and lax.argmax
are parameterized by the dtype, but because our lax.axis_index
implementation only returns 32bit values (in turn because XLA's ReplicaId
produces a uint32
value) 64bit support here won't be complete. But it's easy to add!
new_axis = axis + (d <= axis) | ||
return prim.bind(x, axis=new_axis) | ||
else: | ||
return prim.bind(x, axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this also have to return a second component for the batch axis in the output? How come aren't the tests catching this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, I forgot to add a test for this case! I should add an xmap test that uses systematic schedules.
else: | ||
return dict(params, axis=subst(axis)) | ||
core.axis_substitution_rules[pargmax_p] = _pargminmax_axis_subst_rule | ||
core.axis_substitution_rules[pargmin_p] = _pargminmax_axis_subst_rule |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just use partial(_subst_all_names_in_param, 'axis')
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, your current implementation is ill-typed, because subst(axis)
might return a tuple of axes, which will break all other rules. We need those primitives to have multi-axis support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.argmin
/ np.argmax
don't support tuples of axes, so I don't think these primitives should either. Why do you think we need these primitives to support tuples of axes? The issue might be in the substitution code if it requires that. I was wrong!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see, for the 'soft_pmap case'. I'm not sure what should be returned as an argmin over two axes though... it could be the flattened index, or it could be a tuple of indices.
|
||
def pargmin(x, axis_name): | ||
if isinstance(axis_name, (tuple, list)): | ||
raise TypeError(f"pargmin only accepts a single axis, got {axis_name}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While we can definitely get away with primitives not supporting multiple axes in the user-facing API, we definitely cannot do that internally. xmap
usually replaces each user-facing axis with multiple internal axes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, we could if we change the xmap
lowering strategy not to use the split-axis approach :P
But point taken, I forgot about this constraint.
To fix the jax2tf/primitives_test failures, you need to add the new primitives here |
Follow up from #5682, which added these functions without defining new primitives.
By defining new primitives, we make
vmap
-of-pargmin
more efficient, avoiding the "vectorized efficiency problem", which is the same reason why we're definingpsum
,pdot
, etc. as primitives that support both named and positional axis arguments.Because these primitives support both named and positional arguments, they can replace
argmin_p
andargmax_p
, though that's left as follow-up work.There's some weird failure with 2D float32 pmap test cases on CPU. It seems to be CPU-specific because the failures don't happen on TPU. The same issue does not arise on CPU in the int32 test cases with the same shape. (The failures are not due to the changes in
_axis_index_bind
or_axis_index_translation_rule
; I made those changes to simplify the XLA HLO computation I was looking at while trying to debug this issue.)An example of this failure is
PmapTest.testArgAllReduce_shape=float32[4,2]_axis=1_collective=pargmax
. The expected output is[0, 1, 1, 0]
but the actual output is[ 0, 2147483647, 2147483647, 0]
. The HLO looks correct though:I think I may just file an XLA:CPU bug about this...