Skip to content

Commit 0be998b

Browse files
echarlaixArthurZucker
authored andcommitted
Requires for torch.tensor before casting (#31755)
1 parent b7ee1e8 commit 0be998b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/utils/generic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ def torch_int(x):
762762

763763
import torch
764764

765-
return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
765+
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
766766

767767

768768
def torch_float(x):
@@ -774,7 +774,7 @@ def torch_float(x):
774774

775775
import torch
776776

777-
return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
777+
return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
778778

779779

780780
def filter_out_non_signature_kwargs(extra: Optional[list] = None):

0 commit comments

Comments
 (0)