Skip to content

Commit f915fb3

Browse files
committed
fix: Add decomposition for aten.addmm
- Decompose addmm operator into mul, matmul, and add ops
- Add test case for addmm decomposition
1 parent 0f35954 commit f915fb3

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,14 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
5656
return x
5757

5858

59+
@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
def addmm_replacement(
    input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
) -> torch.Tensor:
    """Decompose ``aten.addmm`` into separate mul, matmul, and add ops.

    Computes ``beta * input_ + alpha * (mat1 @ mat2)``.

    Args:
        input_: Tensor that is scaled by ``beta`` and added to the product.
        mat1: Left operand of the matrix product.
        mat2: Right operand of the matrix product.
        beta: Multiplier applied to ``input_``.
        alpha: Multiplier applied to ``mat1 @ mat2``.

    Returns:
        The sum of the scaled input and the scaled matrix product.
    """
    # Ops are issued in the same order as the nested one-liner would
    # evaluate them: scale the input, multiply, scale the product, add.
    scaled_input = torch.mul(input_, beta)
    product = torch.matmul(mat1, mat2)
    scaled_product = torch.mul(product, alpha)
    return torch.add(scaled_input, scaled_product)
66+
67+
5968
def get_decompositions():
    """Return the module-level ``DECOMPOSITIONS`` registry object."""
    registry = DECOMPOSITIONS
    return registry

py/torch_tensorrt/dynamo/backend/test/test_lowering.py renamed to py/torch_tensorrt/dynamo/backend/test/test_decompositions.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from utils import lower_graph_testing
33
from torch.testing._internal.common_utils import run_tests, TestCase
44
import torch
5+
from torch_tensorrt.dynamo import compile
56

67

78
class TestLowering(TestCase):
@@ -109,6 +110,69 @@ def forward(self, x):
109110
f"The following expected ops were not encountered: {expected_ops_unseen}",
110111
)
111112

113+
def test_lowering_addmm(self):
    """Check that ``aten.addmm`` is decomposed and still matches eager Torch.

    The test has two parts:

    1. Lower a small module that calls ``torch.addmm`` and verify, via
       ``lower_graph_testing``, that add/mul/mm appear in the graph and
       ``aten.addmm`` does not.
    2. Compile the same module with ``torch_tensorrt.dynamo.compile`` and
       compare its output against the symbolically traced module.
    """

    class AddMM(torch.nn.Module):
        def forward(self, x, y, z):
            # Non-default beta/alpha so both scaling multiplications are exercised.
            return torch.addmm(x, y, z, beta=16, alpha=5)

    # Operations expected to be included in the traced graph after decompositions
    expected_ops = {
        torch.ops.aten.add.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mm.default,
    }
    unexpected_ops = {torch.ops.aten.addmm.default}

    # A (1, 1) addend combined with a (7, 8) @ (8, 9) product.
    inputs = [
        torch.rand(1, 1).cuda(),
        torch.rand(7, 8).cuda(),
        torch.rand(8, 9).cuda(),
    ]

    fx_graph = torch.fx.symbolic_trace(AddMM())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    # NOTE(review): presumably clears Dynamo state left over from the lowering
    # run above before compiling again -- confirm.
    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = compile(
        fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
    # The message must be passed by keyword: the third positional parameter of
    # unittest's assertAlmostEqual is `places`, not `msg`.
    self.assertAlmostEqual(
        max_diff,
        0,
        msg="AddMM TRT outputs don't match with the original model.",
    )
175+
112176

113177
if __name__ == "__main__":
114178
run_tests()

0 commit comments

Comments
 (0)