@@ -166,6 +166,10 @@ def __add__(self, other) -> "Tensor":
166
166
Returns:
167
167
Tensor: The result of the addition.
168
168
"""
169
+ if isinstance (other , (int , float )):
170
+ other = self .factory .constant (
171
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
172
+ )
169
173
return generate_op ([self , other ], "eltwise_add" )
170
174
171
175
def __sub__ (self , other ) -> "Tensor" :
@@ -178,6 +182,10 @@ def __sub__(self, other) -> "Tensor":
178
182
Returns:
179
183
Tensor: The result of the subtraction.
180
184
"""
185
+ if isinstance (other , (int , float )):
186
+ other = self .factory .constant (
187
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
188
+ )
181
189
return generate_op ([self , - other ], "eltwise_add" )
182
190
183
191
def __mul__ (self , other ) -> "Tensor" :
@@ -190,6 +198,10 @@ def __mul__(self, other) -> "Tensor":
190
198
Returns:
191
199
Tensor: The result of the multiplication.
192
200
"""
201
+ if isinstance (other , (int , float )):
202
+ other = self .factory .constant (
203
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
204
+ )
193
205
return generate_op ([self , other ], "eltwise_mul" )
194
206
195
207
def __truediv__ (self , other ) -> "Tensor" :
@@ -202,8 +214,76 @@ def __truediv__(self, other) -> "Tensor":
202
214
Returns:
203
215
Tensor: The result of the division.
204
216
"""
217
+ if isinstance (other , (int , float )):
218
+ other = self .factory .constant (
219
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
220
+ )
205
221
return generate_op ([self , other ], "eltwise_div" )
206
222
223
+ def __radd__ (self , other ) -> "Tensor" :
224
+ """
225
+ Add two tensors element-wise.
226
+
227
+ Args:
228
+ other (Tensor): The tensor to be added.
229
+
230
+ Returns:
231
+ Tensor: The result of the addition.
232
+ """
233
+ if isinstance (other , (int , float )):
234
+ other = self .factory .constant (
235
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
236
+ )
237
+ return generate_op ([other , self ], "eltwise_add" )
238
+
239
+ def __rsub__ (self , other ) -> "Tensor" :
240
+ """
241
+ Subtract two tensors element-wise.
242
+
243
+ Args:
244
+ other (Tensor): The tensor to be subtracted.
245
+
246
+ Returns:
247
+ Tensor: The result of the subtraction.
248
+ """
249
+ if isinstance (other , (int , float )):
250
+ other = self .factory .constant (
251
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
252
+ )
253
+ return generate_op ([other , - self ], "eltwise_add" )
254
+
255
+ def __rmul__ (self , other ) -> "Tensor" :
256
+ """
257
+ Multiply two tensors element-wise.
258
+
259
+ Args:
260
+ other (Tensor): The tensor to be multiplied.
261
+
262
+ Returns:
263
+ Tensor: The result of the multiplication.
264
+ """
265
+ if isinstance (other , (int , float )):
266
+ other = self .factory .constant (
267
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
268
+ )
269
+ return generate_op ([other , self ], "eltwise_mul" )
270
+
271
+ def __rtruediv__ (self , other ) -> "Tensor" :
272
+ """
273
+ Divide two tensors element-wise.
274
+
275
+ Args:
276
+ other (Tensor): The tensor to be divided.
277
+
278
+ Returns:
279
+ Tensor: The result of the division.
280
+ """
281
+ if isinstance (other , (int , float )):
282
+ other = self .factory .constant (
283
+ torch .tensor ([other ], dtype = self .dtype .torch_dtype )
284
+ )
285
+ return generate_op ([other , self ], "eltwise_div" )
286
+
207
287
def __neg__ (self ) -> "Tensor" :
208
288
"""
209
289
Negate the tensor.
0 commit comments