add pargmin_p / pargmax_p primitives #5694

Open · wants to merge 1 commit into main from pargmax-primitive

Conversation

mattjj (Collaborator) commented on Feb 10, 2021

Follow up from #5682, which added these functions without defining new primitives.

By defining new primitives, we make vmap-of-pargmin more efficient, avoiding the "vectorized efficiency problem". This is the same reason we're defining psum, pdot, etc. as primitives that support both named and positional axis arguments.

Because these primitives support both named and positional arguments, they can replace argmin_p and argmax_p, though that's left as follow-up work.


There's some weird failure with 2D float32 pmap test cases on CPU. It seems to be CPU-specific because the failures don't happen on TPU. The same issue does not arise on CPU in the int32 test cases with the same shape. (The failures are not due to the changes in _axis_index_bind or _axis_index_translation_rule; I made those changes to simplify the XLA HLO computation I was looking at while trying to debug this issue.)

An example of this failure is PmapTest.testArgAllReduce_shape=float32[4,2]_axis=1_collective=pargmax. The expected output is [0, 1, 1, 0] but the actual output is [ 0, 2147483647, 2147483647, 0]. The HLO looks correct though:

HloModule pmap__lambda_.26

primitive_computation_max.9 {
  parameter.10 = f32[] parameter(0)
  parameter.11 = f32[] parameter(1)
  ROOT maximum.12 = f32[] maximum(parameter.10, parameter.11)
}

primitive_computation_min.18 {
  parameter.19 = s32[] parameter(0)
  parameter.20 = s32[] parameter(1)
  ROOT minimum.21 = s32[] minimum(parameter.19, parameter.20)
}

ENTRY pmap__lambda_.26 {
  constant.2 = pred[] constant(false)
  constant.3 = pred[] constant(false)
  parameter.1 = f32[4]{0} parameter(0), parameter_replication={false}
  all-reduce.13 = f32[4]{0} all-reduce(parameter.1), replica_groups={{0,1}}, to_apply=primitive_computation_max.9
  tuple.14 = (f32[4]{0}) tuple(all-reduce.13)
  get-tuple-element.15 = f32[4]{0} get-tuple-element(tuple.14), index=0
  compare.16 = pred[4]{0} compare(get-tuple-element.15, parameter.1), direction=EQ
  replica-id.4 = u32[] replica-id()
  convert.5 = s32[] convert(replica-id.4)
  broadcast.6 = s32[4]{0} broadcast(convert.5), dimensions={}
  constant.7 = s32[] constant(2147483647)
  broadcast.8 = s32[4]{0} broadcast(constant.7), dimensions={}
  select.17 = s32[4]{0} select(compare.16, broadcast.6, broadcast.8)
  all-reduce.22 = s32[4]{0} all-reduce(select.17), replica_groups={{0,1}}, to_apply=primitive_computation_min.18
  tuple.23 = (s32[4]{0}) tuple(all-reduce.22)
  get-tuple-element.24 = s32[4]{0} get-tuple-element(tuple.23), index=0
  ROOT tuple.25 = (s32[4]{0}) tuple(get-tuple-element.24)
}
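For reference, the HLO above corresponds to roughly this recipe in terms of public JAX collectives (a sketch with a hypothetical helper name, not the PR's actual lowering rules):

```python
import jax
import jax.numpy as jnp
from jax import lax

def pargmax_sketch(x, axis_name):
    # 1. max-reduce the values across the named axis (all-reduce.13 above).
    m = lax.pmax(x, axis_name)
    # 2. where a lane holds the max, propose its index; otherwise INT32_MAX
    #    (the select over broadcast.6 / broadcast.8 above).
    idx = lax.axis_index(axis_name)
    sentinel = jnp.iinfo(jnp.int32).max  # 2147483647
    cand = jnp.where(x == m, idx, sentinel)
    # 3. min-reduce the proposals, so ties resolve to the lowest index.
    return lax.pmin(cand, axis_name)

# Column 0 has its max (2.0) in lane 1; column 1 has its max (3.0) in lane 0.
x = jnp.array([[1., 3.], [2., 0.]])  # mapped over axis 0, named 'i'
out = jax.vmap(lambda v: pargmax_sketch(v, 'i'), axis_name='i')(x)
```

If a lane's sentinel ever survives the final pmin (as in the buggy output above), it shows up as 2147483647, which suggests the equality compare or the min all-reduce went wrong on that backend.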

I think I may just file an XLA:CPU bug about this...

google-cla bot added the "cla: yes" label on Feb 10, 2021
mattjj force-pushed the pargmax-primitive branch 2 times, most recently from 4108482 to a3a439e on Feb 10, 2021 03:13
mattjj added the "pull ready" label (Ready for copybara import and testing) on Feb 10, 2021
mattjj force-pushed the pargmax-primitive branch 4 times, most recently from 9a72481 to 68b8f73 on Feb 10, 2021 03:51
mattjj requested a review from apaszke on Feb 10, 2021 04:16
mattjj marked this pull request as ready for review on Feb 10, 2021 04:16
  else:
    # TODO(mattjj): avals with names, check and eliminate axis name
    shape = np.shape(x)
    return ShapedArray(shape, np.dtype('int32'))
Member commented:
I thought that argmin and argmax always return int64? Or they do at least when used from the NumPy API. The primitives themselves are parameterized by the integer type.

mattjj (Collaborator, Author) replied:

Indeed we'd want that for the jnp API, though we don't necessarily need it on the lax API here.

lax.argmin and lax.argmax are parameterized by the dtype, but because our lax.axis_index implementation only returns 32-bit values (in turn because XLA's ReplicaId produces a uint32 value), 64-bit support here won't be complete. But it's easy to add!
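The int32 constraint mentioned here is easy to observe from the public API; a small sketch using vmap with a named axis:

```python
import jax
import jax.numpy as jnp
from jax import lax

# axis_index yields 32-bit indices along the named axis, which is why
# the pargmin/pargmax output avals in this PR are int32.
idx = jax.vmap(lambda _: lax.axis_index('i'), axis_name='i')(jnp.zeros(4))
```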

    new_axis = axis + (d <= axis)
    return prim.bind(x, axis=new_axis)
  else:
    return prim.bind(x, axis=axis)
Member commented:

Doesn't this also have to return a second component for the batch axis in the output? How come the tests aren't catching this?

mattjj (Collaborator, Author) replied:

Good catch, I forgot to add a test for this case! I should add an xmap test that uses systematic schedules.

  else:
    return dict(params, axis=subst(axis))

core.axis_substitution_rules[pargmax_p] = _pargminmax_axis_subst_rule
core.axis_substitution_rules[pargmin_p] = _pargminmax_axis_subst_rule
Member commented:

You can just use partial(_subst_all_names_in_param, 'axis') here

Member commented:

Also, your current implementation is ill-typed, because subst(axis) might return a tuple of axes, which will break all other rules. We need those primitives to have multi-axis support.
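The typing hazard here can be shown in plain Python, independent of JAX internals (all names below are hypothetical):

```python
# xmap's split-axis lowering can substitute one user-facing axis name with
# a tuple of internal names, so a rule that stores subst(axis) directly
# silently turns a single-name param into a tuple, breaking rules that
# assume axis is a single name.
def naive_subst_rule(params, subst):
    return dict(params, axis=subst(params['axis']))  # may now be a tuple!

subst = lambda name: ('i0', 'i1') if name == 'i' else (name,)
new_params = naive_subst_rule({'axis': 'i'}, subst)
```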

mattjj (Collaborator, Author) replied on Feb 10, 2021:

np.argmin / np.argmax don't support tuples of axes, so I don't think these primitives should either. Why do you think we need these primitives to support tuples of axes? The issue might be in the substitution code if it requires that. (Edit: I was wrong!)

mattjj (Collaborator, Author) replied:

Oh I see, for the soft_pmap case. I'm not sure what should be returned as an argmin over two axes, though: it could be the flattened index, or it could be a tuple of indices.
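The two candidate semantics for a multi-axis argmin can be illustrated with plain NumPy, where argmin over all axes flattens and np.unravel_index recovers the per-axis tuple:

```python
import numpy as np

x = np.array([[3, 1],
              [0, 2]])
flat = int(np.argmin(x))                 # flattened index of the minimum
multi = np.unravel_index(flat, x.shape)  # equivalent tuple of indices
```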


def pargmin(x, axis_name):
  if isinstance(axis_name, (tuple, list)):
    raise TypeError(f"pargmin only accepts a single axis, got {axis_name}")
Member commented:

While we can definitely get away with primitives not supporting multiple axes in the user-facing API, we definitely cannot do that internally. xmap usually replaces each user-facing axis with multiple internal axes.

mattjj (Collaborator, Author) replied:

Well, we could if we change the xmap lowering strategy not to use the split-axis approach :P

But point taken, I forgot about this constraint.

gnecula (Collaborator) commented on Feb 10, 2021:

To fix the jax2tf/primitives_test failures, you need to add the new primitives here
