Commit d0c6460

Make TensorCoreTiledLayout import more robust (#1912)
1 parent d3039da commit d0c6460

2 files changed: +5, -19 lines

torchtune/training/quantization.py

Lines changed: 4 additions & 4 deletions
@@ -6,11 +6,11 @@
 
 from typing import Callable, Optional
 
-from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API
-
-if _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API:
+try:
+    # torchao 0.7+
     from torchao.dtypes import TensorCoreTiledLayout
-else:
+except ImportError:
+    # torchao 0.6 and before
     from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
 
 from torchao.quantization import (
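
For context, the pattern this commit lands on is a plain try/except ImportError fallback, so downstream code always sees a single name regardless of which torchao version is installed. A minimal sketch, assuming torchao is available; the inner_k_tiles keyword on the last line is an illustrative assumption about how the layout object is typically constructed, not something taken from this diff:

try:
    # torchao 0.7+ exposes the class as TensorCoreTiledLayout.
    from torchao.dtypes import TensorCoreTiledLayout
except ImportError:
    # torchao 0.6 and before shipped it as TensorCoreTiledLayoutType.
    from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout

# Assumed usage for illustration: the layout object is handed to torchao's
# int4 weight-only quantization config; exact keywords may differ by version.
layout = TensorCoreTiledLayout(inner_k_tiles=8)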

torchtune/utils/_import_guard.py

Lines changed: 1 addition & 15 deletions
@@ -5,25 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-import torchao
-from torchtune.utils._version import _is_fbcode, _nightly_version_ge, torch_version_ge
+from torchtune.utils._version import torch_version_ge
 
 # We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
 _SUPPORTS_FLEX_ATTENTION = (
     torch_version_ge("2.5.0")
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability() >= (7, 5)
 )
-
-torchao_version = torchao.__version__
-
-_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = _is_fbcode() or (
-    not _is_fbcode()
-    and (
-        ("dev" not in torchao_version and torchao_version >= "0.7.0")
-        or (
-            "dev" in torchao_version
-            and _nightly_version_ge(torchao_version, "2024-10-10")
-        )
-    )
-)
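
After this change the only guard left in _import_guard.py is _SUPPORTS_FLEX_ATTENTION. A minimal sketch of how a caller can gate on it; the torch.nn.attention.flex_attention import path exists on PyTorch 2.5+, and this call site is illustrative rather than taken from this diff:

from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION

if _SUPPORTS_FLEX_ATTENTION:
    # Safe only when torch >= 2.5.0 and the GPU is SM75 or newer, per the guard.
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention
else:
    # Illustrative fallback: callers fall back to standard scaled dot-product attention.
    create_block_mask, flex_attention = None, None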
