Skip to content

Commit e0988aa

Browse files
authored
Add TorchExpm1Visitor (#78)
Follow-up for #77
1 parent 7bf6f67 commit e0988aa

File tree

6 files changed

+62
-2
lines changed

6 files changed

+62
-2
lines changed

tests/fixtures/misc/checker/expm1.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
a = torch.randn(5)
3+
b = torch.exp(a) - 1
4+
c = torch.exp(a) - 1.0
5+
6+
ret = (torch.exp(a) - 1) * torch.exp(2 * b)
7+
8+
# False negative: can not detect currently
9+
x = a.exp() - 1
10+
11+
# False negative: should be rare and would complicate implementation
12+
x = -1 + torch.exp(a)

tests/fixtures/misc/checker/expm1.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
3:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
2+
4:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
3+
6:7 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.

tests/test_torchfix.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,15 @@ def pytest_generate_tests(metafunc):
3737
("TOR102,TOR101", {"TOR102", "TOR101"}),
3838
(
3939
"TOR1,TOR102",
40-
{"TOR102", "TOR101", "TOR103", "TOR104", "TOR105", "TOR106"},
40+
{
41+
"TOR101",
42+
"TOR102",
43+
"TOR103",
44+
"TOR104",
45+
"TOR105",
46+
"TOR106",
47+
"TOR107",
48+
},
4149
),
4250
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
4351
]

torchfix/torchfix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .visitors import (
1111
TorchDeprecatedSymbolsVisitor,
12+
TorchExpm1Visitor,
1213
TorchLog1pVisitor,
1314
TorchNonPublicAliasVisitor,
1415
TorchReentrantCheckpointVisitor,
@@ -29,6 +30,7 @@
2930

3031
ALL_VISITOR_CLS = [
3132
TorchDeprecatedSymbolsVisitor,
33+
TorchExpm1Visitor,
3234
TorchLog1pVisitor,
3335
TorchNonPublicAliasVisitor,
3436
TorchRequireGradVisitor,

torchfix/visitors/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from .deprecated_symbols import TorchDeprecatedSymbolsVisitor
22
from .internal import TorchScopedLibraryVisitor
33
from .misc import (
4+
TorchExpm1Visitor,
5+
TorchLog1pVisitor,
46
TorchReentrantCheckpointVisitor,
57
TorchRequireGradVisitor,
6-
TorchLog1pVisitor,
78
)
89
from .nonpublic import TorchNonPublicAliasVisitor
910
from .performance import TorchSynchronizedDataLoaderVisitor
@@ -16,6 +17,7 @@
1617

1718
__all__ = [
1819
"TorchDeprecatedSymbolsVisitor",
20+
"TorchExpm1Visitor",
1921
"TorchLog1pVisitor",
2022
"TorchNonPublicAliasVisitor",
2123
"TorchReentrantCheckpointVisitor",

torchfix/visitors/misc/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,36 @@ def visit_Call(self, node):
121121
message=self.ERRORS[0].message(),
122122
replacement=None,
123123
)
124+
125+
126+
class TorchExpm1Visitor(TorchVisitor):
127+
"""
128+
Suggest using `torch.special.expm1(x)` instead of `torch.exp(x) - 1`.
129+
"""
130+
131+
ERRORS = [
132+
TorchError(
133+
"TOR107",
134+
(
135+
"Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. "
136+
"It is more accurate for small values of `x`."
137+
),
138+
)
139+
]
140+
141+
def visit_BinaryOperation(self, node):
142+
if m.matches(
143+
node,
144+
m.BinaryOperation(
145+
left=m.Call(),
146+
operator=m.Subtract(),
147+
right=m.Integer(value="1") | m.Float(value="1.0"),
148+
),
149+
):
150+
if self.get_qualified_name_for_call(node.left) == "torch.exp":
151+
self.add_violation(
152+
node,
153+
error_code=self.ERRORS[0].error_code,
154+
message=self.ERRORS[0].message(),
155+
replacement=None,
156+
)

0 commit comments

Comments
 (0)