This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Commit 66c1205

Fix ops and r_ops in case of float and int (#88)

* Fix ops and r_ops in case of float and int
* Random input

1 parent a35dea0 · commit 66c1205

2 files changed: +120 additions, 0 deletions
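In practical terms, the change lets Python ints and floats sit on either side of arithmetic with a backend Tensor: the scalar is wrapped in a graph constant carrying the tensor's dtype, and the new reflected operators (__radd__, __rsub__, __rmul__, __rtruediv__) cover the scalar-on-the-left case. A minimal usage sketch, assuming NNFactory and float16 can be imported the way the updated test module uses them:

import torch

# Assumed import locations, mirroring test/python/test_tensor.py.
from intel_npu_acceleration_library.backend import NNFactory
from intel_npu_acceleration_library.dtypes import float16

model = NNFactory()
t = model.parameter([12, 231], float16)
out = 2.0 - t  # scalar on the left: dispatches to Tensor.__rsub__
model.compile()

x = torch.rand([12, 231]).to(torch.float16)
result = model(x)  # approximately 2.0 - x, computed by the compiled graph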

intel_npu_acceleration_library/backend/tensor.py

Lines changed: 80 additions & 0 deletions
@@ -166,6 +166,10 @@ def __add__(self, other) -> "Tensor":
         Returns:
             Tensor: The result of the addition.
         """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
         return generate_op([self, other], "eltwise_add")

     def __sub__(self, other) -> "Tensor":

@@ -178,6 +182,10 @@ def __sub__(self, other) -> "Tensor":
         Returns:
             Tensor: The result of the subtraction.
         """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
         return generate_op([self, -other], "eltwise_add")

     def __mul__(self, other) -> "Tensor":

@@ -190,6 +198,10 @@ def __mul__(self, other) -> "Tensor":
         Returns:
             Tensor: The result of the multiplication.
         """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
         return generate_op([self, other], "eltwise_mul")

     def __truediv__(self, other) -> "Tensor":

@@ -202,8 +214,76 @@ def __truediv__(self, other) -> "Tensor":
         Returns:
             Tensor: The result of the division.
         """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
         return generate_op([self, other], "eltwise_div")

+    def __radd__(self, other) -> "Tensor":
+        """
+        Add two tensors element-wise.
+
+        Args:
+            other (Tensor): The tensor to be added.
+
+        Returns:
+            Tensor: The result of the addition.
+        """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
+        return generate_op([other, self], "eltwise_add")
+
+    def __rsub__(self, other) -> "Tensor":
+        """
+        Subtract two tensors element-wise.
+
+        Args:
+            other (Tensor): The tensor to be subtracted.
+
+        Returns:
+            Tensor: The result of the subtraction.
+        """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
+        return generate_op([other, -self], "eltwise_add")
+
+    def __rmul__(self, other) -> "Tensor":
+        """
+        Multiply two tensors element-wise.
+
+        Args:
+            other (Tensor): The tensor to be multiplied.
+
+        Returns:
+            Tensor: The result of the multiplication.
+        """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
+        return generate_op([other, self], "eltwise_mul")
+
+    def __rtruediv__(self, other) -> "Tensor":
+        """
+        Divide two tensors element-wise.
+
+        Args:
+            other (Tensor): The tensor to be divided.
+
+        Returns:
+            Tensor: The result of the division.
+        """
+        if isinstance(other, (int, float)):
+            other = self.factory.constant(
+                torch.tensor([other], dtype=self.dtype.torch_dtype)
+            )
+        return generate_op([other, self], "eltwise_div")
+
     def __neg__(self) -> "Tensor":
         """
         Negate the tensor.
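
The pattern behind the new reflected methods is standard Python operator dispatch: when the left operand is a plain int or float, the built-in type returns NotImplemented, so Python falls back to the right operand's __r*__ method, and that is where the scalar gets promoted to a constant. A standalone sketch of the mechanism with a toy class (illustrative only, not the library's Tensor):

class Scalar:
    def __init__(self, value):
        self.value = value

    def _promote(self, other):
        # Mirror of the fix above: wrap plain numbers so both operands share a type.
        return other if isinstance(other, Scalar) else Scalar(float(other))

    def __add__(self, other):  # handles: Scalar + x
        return Scalar(self.value + self._promote(other).value)

    def __radd__(self, other):  # handles: x + Scalar when x is an int or float
        return Scalar(self._promote(other).value + self.value)


print((Scalar(1.0) + 3).value)  # __add__  -> 4.0
print((3 + Scalar(1.0)).value)  # __radd__ -> 4.0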

test/python/test_tensor.py

Lines changed: 40 additions & 0 deletions
@@ -286,3 +286,43 @@ def test_reduce_operations(batch, hidden_dim, axis, op):
         assert 1 - r2_score([reference, 1], [result, 1]) < 0.01
     else:
         assert 1 - r2_score(reference, result) < 0.01
+
+
+@pytest.mark.parametrize("shape", [[1, 128, 32, 64], [12, 231]])
+@pytest.mark.parametrize("op", ["+", "-", "*", "/"])
+@pytest.mark.parametrize("side", ["left", "right"])
+@pytest.mark.parametrize("value", [3, -10])
+def test_float_op(shape, op, side, value):
+    def op_func(a, b):
+        if op == "+":
+            return a + b
+        elif op == "-":
+            return a - b
+        elif op == "*":
+            return a * b
+        elif op == "/":
+            return a / b
+
+    def act(a, b):
+        if side == "left":
+            return op_func(b, a)
+        else:
+            return op_func(a, b)
+
+    x = torch.rand(shape).to(torch.float16) + 2
+    reference = act(x, value)
+
+    model = NNFactory()
+    t1 = model.parameter(shape, float16)
+    out = act(t1, value)
+    model.compile()
+
+    result = model(x)
+
+    assert (
+        1
+        - r2_score(
+            reference.flatten().detach().numpy(), result.flatten().detach().numpy()
+        )
+        < 0.001
+    )
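
The new parametrized case can be run in isolation with pytest's keyword filter, e.g. pytest test/python/test_tensor.py -k test_float_op (plain pytest usage, not a project-specific runner); it sweeps both tensor shapes, all four operators, both operand orders, and a positive and a negative scalar.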
