Commit 30193d8

gmagogsfm authored and facebook-github-bot committed
deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/test/quantize (pytorch#596)
Summary:
Pull Request resolved: facebookresearch/FBGEMM#596
X-link: pytorch#3512
Reviewed By: avikchaudhuri
Differential Revision: D67381311
fbshipit-source-id: 345264f99d6f4b77508b4ea95fe20b3482ad1f04
1 parent 7f68e29 commit 30193d8

File tree

1 file changed: +4, -2 lines changed


fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 4 additions & 2 deletions
@@ -134,7 +134,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         model = TestModule().cuda()
         # bf16 required here
-        _ = torch.export.export(model, (torch.randn(32, 32).to(torch.bfloat16).cuda(),))
+        _ = torch.export.export(
+            model, (torch.randn(32, 32).to(torch.bfloat16).cuda(),), strict=True
+        )
 
     def test_f8f8bf16_export(self) -> None:
         class TestModule(torch.nn.Module):
@@ -161,7 +163,7 @@ def forward(self, xq: torch.Tensor, wq: torch.Tensor) -> torch.Tensor:
         fp8_dtype = torch.float8_e4m3fnuz
         xq = torch.randn(M, K).to(fp8_dtype).cuda()
         wq = torch.randn(N, K).to(fp8_dtype).cuda()
-        _ = torch.export.export(model, (xq, wq))
+        _ = torch.export.export(model, (xq, wq), strict=True)
 
 
     @unittest.skipIf(
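
For context on the change: torch.export.export accepts a strict keyword argument, and strict=True selects TorchDynamo-based (strict) tracing, which raises on unsupported Python constructs instead of falling back. Passing the flag explicitly, as this commit does, keeps the tests pinned to that behavior regardless of the library default. Below is a minimal standalone sketch of the same call pattern; ToyModule is a hypothetical stand-in (the real tests use their own TestModule classes with bf16/fp8 CUDA tensors), and it runs on CPU in float32 so it works without a GPU.

import torch

# Hypothetical stand-in for the test's TestModule; not from the commit.
class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = ToyModule()
example_input = torch.randn(32, 32)

# Passing strict=True explicitly, mirroring this commit: strict mode
# traces the module with TorchDynamo and errors out on graph breaks
# rather than silently changing behavior if the default flips.
exported = torch.export.export(model, (example_input,), strict=True)
print(exported)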
