1 parent b7ee1e8 · commit 0be998b
src/transformers/utils/generic.py
@@ -762,7 +762,7 @@ def torch_int(x):
 
     import torch
 
-    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
+    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
 
 
 def torch_float(x):
@@ -774,7 +774,7 @@ def torch_float(x):
 
     import torch
 
-    return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
+    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
 
 
 def filter_out_non_signature_kwargs(extra: Optional[list] = None):
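
For reference, here is a minimal standalone sketch of the guarded cast this commit introduces (the helper mirrors the patched `torch_int`; the example inputs are illustrative and not part of the commit):

```python
import torch


def torch_int(x):
    # Mirrors the patched helper: cast to an int64 tensor only while tracing
    # AND when x is actually a tensor; plain Python numbers fall back to int(x),
    # which avoids calling .to() on an object that does not have it.
    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)


# Outside of tracing, both tensors and plain scalars become Python ints.
print(torch_int(torch.tensor(3)))  # 3
print(torch_int(3.9))              # 3
```

The same guard is applied to `torch_float`, with `torch.float32` as the cast target.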