This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Commit 13b9e3d

SarahByrneIntel authored
Adding support and testing for chunk tensor operation (#90)
* Add support and test for chunk tensor op
* Fix for chunk tensor op

Co-authored-by: SarahByrneIntel <[email protected]>
1 parent 66c1205 commit 13b9e3d
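For context, the new op is meant to mirror PyTorch's torch.chunk semantics: the tensor is split into the requested number of pieces along one dimension, and the trailing piece is smaller when the dimension size is not evenly divisible. A quick illustration in plain PyTorch (not part of this commit):

import torch

x = torch.arange(10)
# Splitting 10 elements into 3 chunks gives ceil(10/3) = 4 per chunk,
# with the remaining 2 elements in the last chunk.
print(torch.chunk(x, chunks=3, dim=0))
# (tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9]))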

File tree: 2 files changed, +71 -1 lines changed

intel_npu_acceleration_library/backend/tensor.py
test/python/test_tensor.py

intel_npu_acceleration_library/backend/tensor.py

Lines changed: 37 additions & 0 deletions
@@ -899,6 +899,43 @@ def sum(
         sum = sum.to(dtype)
         return sum

+    def chunk(
+        self,
+        chunks: int,
+        dim: int = 0,
+    ) -> Union["Tensor", list]:
+        """
+        Return the list of tensor chunks.
+
+        Args:
+            chunks (int): The number of chunks to return.
+            dim (int): The dimension along which to split the tensor. Default is 0.
+
+        Returns:
+            Union["Tensor", list]: The resulting list of split tensors or a single tensor.
+
+        Raises:
+            ValueError: The input chunks value is not valid.
+        """
+        if chunks <= 0:
+            raise ValueError("The input chunks value is not valid.")
+        if chunks == 1:
+            return self
+        tensors = []
+        remainder = self.shape[dim] % chunks
+        chunk_size = self.shape[dim] // chunks + (1 if remainder > 0 else 0)
+        num_dims = self.dim()
+
+        start_idx = 0
+        for _ in range(chunks):
+            indexes = [slice(None)] * num_dims
+            end_idx = start_idx + chunk_size
+            end_idx = end_idx if end_idx < self.shape[dim] else self.shape[dim]
+            indexes[dim] = slice(start_idx, end_idx)
+            tensors.append(self.__getitem__(tuple(indexes)))
+            start_idx = end_idx
+        return tensors
+
     def to(self, dtype: NPUDtype) -> "Tensor":
         """
         Convert the tensor to the specified data type.
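Taken together, the method computes a ceiling-division chunk size (e.g. shape[dim] = 10 with chunks = 3 gives chunk_size = 4, producing pieces of 4, 4, and 2) and slices the tensor via __getitem__, clamping the final slice to the end of the dimension. Below is a minimal usage sketch, modelled directly on the test added in this commit; the import path is an assumption based on this repo's layout, and exact return types for multi-output models follow the test's indexing:

import torch
from intel_npu_acceleration_library.backend import NNFactory  # assumed import path

X = torch.rand((16, 256)).to(torch.float16)

model = NNFactory()
t1 = model.parameter(X.shape)
_ = t1.chunk(chunks=2, dim=1)  # two 16x128 halves along the hidden dimension
model.compile()

result = model(X)  # chunk outputs from the NPU; indexable per chunk when chunks > 1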

test/python/test_tensor.py

Lines changed: 34 additions & 1 deletion
@@ -270,7 +270,6 @@ def test_reduce_operations(batch, hidden_dim, axis, op):
     reference = eval(f"X.{op}(dim=axis)")
     reference = reference.numpy()

-    print(X.sum())
     model = NNFactory()
     t1 = model.parameter(X.shape)
     _ = eval(f"t1.{op}()") if axis is None else eval(f"t1.{op}(dim=axis)")
@@ -326,3 +325,37 @@ def act(a, b):
         )
         < 0.001
     )
+
+
+@pytest.mark.parametrize("batch", [16, 128])
+@pytest.mark.parametrize("hidden_dim", [256, 512])
+@pytest.mark.parametrize("chunks", [1, 2, 3, 4])
+@pytest.mark.parametrize("axis", [0, 1, -1, -2])
+def test_chunk_operation(batch, hidden_dim, chunks, axis):
+
+    X = torch.rand((batch, hidden_dim)).to(torch.float16)
+
+    reference = X.chunk(chunks=chunks, dim=axis)
+
+    model = NNFactory()
+    t1 = model.parameter(X.shape)
+    _ = t1.chunk(chunks=chunks, dim=axis)
+    model.compile()
+
+    result = model(X)
+
+    if chunks == 1:
+        assert np.isfinite(
+            reference[0].numpy()
+        ).all(), "Pytorch Reference contains NaN or Inf"
+        assert np.isfinite(result.numpy()).all(), "NPU output contains NaN or Inf"
+        assert 1 - r2_score(reference[0].numpy(), result.numpy()) < 0.01
+    else:
+        for i in range(len(reference)):
+            assert np.isfinite(
+                reference[i].numpy()
+            ).all(), "Pytorch Reference contains NaN or Inf"
+            assert np.isfinite(
+                result[i].numpy()
+            ).all(), "NPU output contains NaN or Inf"
+            assert 1 - r2_score(reference[i].numpy(), result[i].numpy()) < 0.01
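To exercise just this test locally, a standard pytest selection along these lines should work (the exact invocation may vary with the repo's test setup):

python -m pytest test/python/test_tensor.py -k test_chunk_operation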
