
Commit 86f15ca

Make TensorCoreTiledLayout import more robust
Summary: Fixes pytorch#1908. Previous attempts (pytorch#1886) to fix this issue still break in local settings, so it is more robust and simpler to just catch the ImportError.

Test Plan:
```
from torchtune.training.quantization import *
```
1 parent d3039da · commit 86f15ca

File tree

2 files changed: +5 −16 lines


torchtune/training/quantization.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -6,11 +6,11 @@
 
 from typing import Callable, Optional
 
-from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API
-
-if _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API:
+try:
+    # torchao 0.7+
     from torchao.dtypes import TensorCoreTiledLayout
-else:
+except ImportError:
+    # torchao 0.6 and before
     from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
 
 from torchao.quantization import (
```
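
The resolved symbol can be sanity-checked directly; below is a minimal sketch of the same try/except import guard, assuming some version of torchao is installed (the print line is illustrative and not part of the commit):

```python
# Resolve the layout class under one name regardless of torchao version.
try:
    # torchao 0.7+ exposes the new class name
    from torchao.dtypes import TensorCoreTiledLayout
except ImportError:
    # torchao 0.6 and before only has the old name; alias it to the new one
    from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout

# Confirm which symbol was actually picked up (illustrative only)
print(f"Resolved {TensorCoreTiledLayout.__module__}.{TensorCoreTiledLayout.__name__}")
```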

torchtune/utils/_import_guard.py

Lines changed: 1 addition & 12 deletions
```diff
@@ -6,7 +6,7 @@
 
 import torch
 import torchao
-from torchtune.utils._version import _is_fbcode, _nightly_version_ge, torch_version_ge
+from torchtune.utils._version import torch_version_ge
 
 # We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
 _SUPPORTS_FLEX_ATTENTION = (
@@ -16,14 +16,3 @@
 )
 
 torchao_version = torchao.__version__
-
-_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = _is_fbcode() or (
-    not _is_fbcode()
-    and (
-        ("dev" not in torchao_version and torchao_version >= "0.7.0")
-        or (
-            "dev" in torchao_version
-            and _nightly_version_ge(torchao_version, "2024-10-10")
-        )
-    )
-)
```
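
For reference, the same feature-detection idea can be factored into a small helper instead of parsing version strings (as the removed _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API logic did for "dev"/nightly builds). The helper below is a hypothetical illustration, not part of torchtune or torchao:

```python
import importlib


def _import_with_fallback(module: str, name: str, fallback_name: str):
    """Hypothetical helper: return `name` from `module` if it exists,
    otherwise fall back to `fallback_name` under the same variable."""
    mod = importlib.import_module(module)
    try:
        return getattr(mod, name)
    except AttributeError:
        return getattr(mod, fallback_name)


# Hypothetical usage mirroring the change in quantization.py:
# TensorCoreTiledLayout = _import_with_fallback(
#     "torchao.dtypes", "TensorCoreTiledLayout", "TensorCoreTiledLayoutType"
# )
```

Detecting the attribute directly avoids the fbcode and nightly-version special cases that the removed boolean flag had to encode.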
