File tree Expand file tree Collapse file tree 2 files changed +5
-19
lines changed Expand file tree Collapse file tree 2 files changed +5
-19
lines changed Original file line number Diff line number Diff line change 6
6
7
7
from typing import Callable , Optional
8
8
9
- from torchtune .utils ._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API
10
-
11
- if _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API :
9
+ try :
10
+ # torchao 0.7+
12
11
from torchao .dtypes import TensorCoreTiledLayout
13
- else :
12
+ except ImportError :
13
+ # torchao 0.6 and before
14
14
from torchao .dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
15
15
16
16
from torchao .quantization import (
Original file line number Diff line number Diff line change 5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import torch
8
- import torchao
9
- from torchtune .utils ._version import _is_fbcode , _nightly_version_ge , torch_version_ge
8
+ from torchtune .utils ._version import torch_version_ge
10
9
11
10
# We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
12
11
_SUPPORTS_FLEX_ATTENTION = (
13
12
torch_version_ge ("2.5.0" )
14
13
and torch .cuda .is_available ()
15
14
and torch .cuda .get_device_capability () >= (7 , 5 )
16
15
)
17
-
18
- torchao_version = torchao .__version__
19
-
20
- _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = _is_fbcode () or (
21
- not _is_fbcode ()
22
- and (
23
- ("dev" not in torchao_version and torchao_version >= "0.7.0" )
24
- or (
25
- "dev" in torchao_version
26
- and _nightly_version_ge (torchao_version , "2024-10-10" )
27
- )
28
- )
29
- )
You can’t perform that action at this time.
0 commit comments