diff --git a/pytensor/tensor/rewriting/uncanonicalize.py b/pytensor/tensor/rewriting/uncanonicalize.py
index 09b6819737..ae631665ec 100644
--- a/pytensor/tensor/rewriting/uncanonicalize.py
+++ b/pytensor/tensor/rewriting/uncanonicalize.py
@@ -174,11 +174,15 @@ def local_dimshuffle_alloc(fgraph, node):
 def local_dimshuffle_subtensor(fgraph, node):
     """If a subtensor is inside a dimshuffle which only drop
     broadcastable dimensions, scrap the dimshuffle and index the
-    subtensor with 0
+    subtensor in a way that avoids the degenerate dimension
 
     x[i:j, :, k:l].dimshuffle(0, 2) =>
         x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
 
+    x[i:j, k:l, :].dimshuffle(0, 2) => x[i:j, k, :]
+    x[i:j, k:, :].dimshuffle(0, 2) => x[i:j, k, :]
+    x[i:j, :l, :].dimshuffle(0, 2) => x[i:j, 0, :]
+
     """
     if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
         # the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
@@ -217,24 +221,40 @@
             new_idx_list = list(input_.owner.op.idx_list)
             new_inputs = [input_.owner.inputs[0]]
             zero = constant(0)
-            slice_attr_list = ["start", "stop", "step"]
             j = 0
             slice_i = -1
             subtensor_removed_dims = 0
             for i, idx in enumerate(input_.owner.op.idx_list):
                 if isinstance(idx, slice):
-                    past_j = j
                     slice_i += 1
-                    for slice_attr in slice_attr_list:
-                        if getattr(idx, slice_attr) is not None:
-                            new_inputs += [input_.owner.inputs[1 + j]]
-                            j += 1
-                    # if past_j == j indicates a slice(None, None, None),
-                    # that's where we want to index with 0 if it is also at
-                    # the same spot of a missing dim
-                    if past_j == j and slice_i in missing_dims:
-                        new_idx_list[i] = zero
-                        new_inputs += [zero]
+                    if slice_i in missing_dims:
+                        # Missing dim is a slice(None), remove by indexing by 0
+                        if idx == slice(None):
+                            new_idx_list[i] = zero
+                            new_inputs += [zero]
+                        # Missing dim is an ordinary slice with known output dim length of 1
+                        # Remove by indexing by start
+                        else:
+                            if idx.start is None:
+                                start = zero
+                            else:
+                                start = input_.owner.inputs[1 + j]
+                                j += 1
+                            new_idx_list[i] = start
+                            new_inputs += [start]
+
+                            # Ignore useless stop and step input if there is one
+                            for slice_attr in ("stop", "step"):
+                                if getattr(idx, slice_attr) is not None:
+                                    j += 1
+
+                    # Keep non-dropped slice inputs
+                    else:
+                        for slice_attr in ("start", "stop", "step"):
+                            if getattr(idx, slice_attr) is not None:
+                                new_inputs += [input_.owner.inputs[1 + j]]
+                                j += 1
+                # Keep non-dropped non-slice inputs
                 else:
                     new_inputs += [input_.owner.inputs[1 + j]]
                     j += 1
diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py
index 995352374e..865da83137 100644
--- a/tests/tensor/rewriting/test_uncanonicalize.py
+++ b/tests/tensor/rewriting/test_uncanonicalize.py
@@ -214,3 +214,11 @@ def test_local_dimshuffle_subtensor():
     assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
         {x: np.ones((5, 1, 6, 7))}
     ).shape == (5, 3, 7)
+
+    # Test dropped sliced dimensions
+    x = matrix("x", shape=(5, 4), dtype="float64")
+
+    assert x[2:3, :-2].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (2,)
+    assert x[:1, 0:3].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)
+    assert x[-1:, :].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (4,)
+    assert x[4:3:-1, 1:].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)
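
For intuition, here is a minimal NumPy-only sketch (not part of the patch) of the equivalence the new branch relies on: when a slice is statically known to yield a length-1 dimension, indexing by the slice's start (or by 0 when the start is None) selects the same elements without ever creating the degenerate axis that the DimShuffle would otherwise have to drop.

import numpy as np

x = np.arange(20).reshape(5, 4)

# x[2:3, :-2] keeps a degenerate leading axis of length 1;
# dimshuffle(1) would drop it, much like squeeze(axis=0) here.
via_dimshuffle = x[2:3, :-2].squeeze(axis=0)

# The rewrite instead indexes by the slice's start, so the
# degenerate axis is never created in the first place.
via_start_index = x[2, :-2]

assert via_dimshuffle.shape == via_start_index.shape == (2,)
assert np.array_equal(via_dimshuffle, via_start_index)

# Negative-step case from the new tests: x[4:3:-1] has length 1
# and start 4, so indexing by the start is again equivalent.
assert np.array_equal(x[4:3:-1, 1:].squeeze(axis=0), x[4, 1:])

This also shows why the slice(None) case is handled separately in the rewrite: a bare `:` has no start input in the Subtensor node, so a constant 0 is used instead of consuming an input.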