
Commit 88349ff

Maggie Moss authored and facebook-github-bot committed
Flip deep learning (pytorch#2831)
Summary:
Pull Request resolved: pytorch#2831

X-link: facebookresearch/FBGEMM#26

Reviewed By: connernilsen

Differential Revision: D59653154

fbshipit-source-id: 2be7d6c6d540db34391bbd6e0dafe62eec185e82
1 parent 304b8a2 · commit 88349ff

File tree

10 files changed: +47, -1 lines changed


fbgemm_gpu/codegen/genscript/generate_backward_split.py

Lines changed: 8 additions & 0 deletions
@@ -13,6 +13,14 @@
 from typing import List

 try:
+    # pyre-fixme[21]: Could not find name `ArgType` in
+    #  `deeplearning.fbgemm.fbgemm_gpu.codegen.genscript.optimizers`.
+    # pyre-fixme[21]: Could not find name `OptimItem` in
+    #  `deeplearning.fbgemm.fbgemm_gpu.codegen.genscript.optimizers`.
+    # pyre-fixme[21]: Could not find name `OptimizerArgsSet` in
+    #  `deeplearning.fbgemm.fbgemm_gpu.codegen.genscript.optimizers`.
+    # pyre-fixme[21]: Could not find name `generate_optimized_grad_sum_loop_access`
+    #  in `deeplearning.fbgemm.fbgemm_gpu.codegen.genscript.optimizers`.
     from .optimizers import *
     from .common import CodeTemplate
     from .optimizer_args import OptimizerArgsSet
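
For readers unfamiliar with the convention: a `# pyre-fixme[N]` comment on the line above an expression suppresses Pyre error code `N` for that line only, and code 21 means Pyre could not resolve an imported name. Wildcard imports trigger it easily, because a static checker may not see names a module binds dynamically at import time. A minimal two-file sketch of the situation, with hypothetical names (`optimizers.py`, `adagrad`) that are illustrative rather than FBGEMM's:

# optimizers.py (hypothetical): names are injected at import time, so a
# static checker cannot enumerate what `import *` will bind.
def _build(name):
    return lambda lr=0.01: (name, lr)

globals().update({n: _build(n) for n in ("adagrad", "adam")})

# caller.py (hypothetical): Pyre reports code 21 at the wildcard import;
# the fixme on the preceding line suppresses exactly that error without
# changing runtime behavior.
# pyre-fixme[21]: Could not find name `adagrad` in `optimizers`.
from optimizers import *  # noqa: F403

print(adagrad(lr=0.1))  # prints ("adagrad", 0.1) at runtime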

fbgemm_gpu/codegen/genscript/optimizer_args.py

Lines changed: 1 addition & 0 deletions
@@ -265,6 +265,7 @@ def make_function_schema_arg(ty: ArgType, name: str, default: Union[int, float])
         ArgType.PLACEHOLDER_TENSOR: tensor_arg,
         ArgType.INT: lambda x: int_arg(x, default=int(default)),
         ArgType.FLOAT: lambda x: float_arg(x, default=default),
+        # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[float, int]`.
         ArgType.SYM_INT: lambda x: schema_sym_int_arg(x, default=default),
     }[ty](name)
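
The suppressed code here is 6, Pyre's incompatible-parameter-type error: `default` is declared `Union[int, float]`, and one entry of the dispatch table forwards it to a helper with a narrower parameter. A hedged sketch of the same dict-of-lambdas dispatch, with hypothetical helpers standing in for FBGEMM's:

from typing import Union

def int_arg(name: str, default: int = 0) -> str:
    return f"int {name} = {default}"

def float_arg(name: str, default: float = 0.0) -> str:
    return f"float {name} = {default}"

def make_arg(kind: str, name: str, default: Union[int, float]) -> str:
    # Dispatch on `kind` through a dict of lambdas, as in the hunk above.
    return {
        "float": lambda n: float_arg(n, default=default),
        # `default` may be a float at runtime, but int_arg declares `int`;
        # the fixme suppresses Pyre's code-6 report on this entry only.
        # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[float, int]`.
        "int": lambda n: int_arg(n, default=default),
    }[kind](name)

print(make_arg("int", "rows", 4))  # int rows = 4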

fbgemm_gpu/test/quantize/fp8_rowwise_test.py

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,9 @@
 from hypothesis import given, settings, Verbosity

 from . import common  # noqa E402
+
+# pyre-fixme[21]: Could not find name `open_source` in
+#  `deeplearning.fbgemm.fbgemm_gpu.test.quantize.common`.
 from .common import open_source

 if open_source:
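
Several of the quantize tests below receive this identical three-line suppression, so the pattern is worth spelling out once. Each test imports an `open_source` flag from the package's shared `common` module; Pyre reports code 21 at every import site, which suggests the flag is bound conditionally at runtime rather than declared statically. A minimal hypothetical sketch of such a conditionally bound flag:

# common.py (hypothetical): the flag is only bound once we know which
# build we are running in, so a static checker may not see the name.
try:
    import fbgemm_gpu_internal_build  # hypothetical internal-only module
    open_source = False
except ImportError:
    open_source = True

Each test then branches on `if open_source:` to pick its setup, and suppresses the lookup error rather than restructure the shared module.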

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@
 from hypothesis import assume, given, HealthCheck, settings

 from . import common  # noqa E402
+
+# pyre-fixme[21]: Could not find name `open_source` in
+#  `deeplearning.fbgemm.fbgemm_gpu.test.quantize.common`.
 from .common import (
     fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference,
     fused_rowwise_8bit_dequantize_reference,

fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@
 from hypothesis import assume, given, HealthCheck, settings

 from . import common  # noqa E402
+
+# pyre-fixme[21]: Could not find name `open_source` in
+#  `deeplearning.fbgemm.fbgemm_gpu.test.quantize.common`.
 from .common import (
     bytes_to_half_floats,
     fused_rowwise_nbit_quantize_dequantize_reference,

fbgemm_gpu/test/quantize/mixed_dim_int8_test.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@
 from hypothesis import given, HealthCheck, settings

 from . import common  # noqa E402
+
+# pyre-fixme[21]: Could not find name `open_source` in
+#  `deeplearning.fbgemm.fbgemm_gpu.test.quantize.common`.
 from .common import open_source

 if open_source:

fbgemm_gpu/test/quantize/msfp_test.py

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,9 @@
 from hypothesis import given, HealthCheck, settings

 from . import common  # noqa E402
+
+# pyre-fixme[21]: Could not find name `open_source` in
+#  `deeplearning.fbgemm.fbgemm_gpu.test.quantize.common`.
 from .common import open_source

 if open_source:

fbgemm_gpu/test/quantize/mx/common.py

Lines changed: 8 additions & 0 deletions
@@ -142,6 +142,8 @@ def _get_format_params(  # noqa

     min_norm = _get_min_norm(ebits)

+    # pyre-fixme[6]: For 1st argument expected `ElemFormat` but got `Union[None,
+    #  ElemFormat, int]`.
     _FORMAT_CACHE[fmt] = (ebits, mbits, emax, max_norm, min_norm)

     return ebits, mbits, emax, max_norm, min_norm
@@ -308,6 +310,7 @@ def check_diff_quantize(
         raise IndexError

     # Convert to numpy
+    # pyre-fixme[9]: x has type `Tensor`; used as `Union[ndarray, Tensor]`.
     x = np.array(x) if type(x) is list else x
     x = x.cpu().numpy() if type(x) is torch.Tensor else x
     y1 = y1.detach().cpu().numpy()
@@ -510,11 +513,16 @@ def _quantize_elemwise_core(
     private_exp = None

     # Scale up so appropriate number of bits are in the integer portion of the number
+    # pyre-fixme[6]: For 3rd argument expected `Optional[int]` but got
+    #  `Optional[Tensor]`.
     out = _safe_lshift(out, bits - 2, private_exp)

+    # pyre-fixme[6]: For 3rd argument expected `RoundingMode` but got `str`.
     out = _round_mantissa(out, bits, round, clamp=False)

     # Undo scaling
+    # pyre-fixme[6]: For 3rd argument expected `Optional[int]` but got
+    #  `Optional[Tensor]`.
     out = _safe_rshift(out, bits - 2, private_exp)

     # Set values > max_norm to Inf if desired, else clamp them
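
Two different codes are suppressed in this file. Code 9 fires when a value assigned to a name contradicts the name's declared type, which happens naturally when a parameter is reused for staged conversions; code 6, as before, flags an argument wider than the parameter accepts. A hypothetical sketch of both, loosely modeled on the hunks above:

from typing import Optional

import numpy as np
import torch

def to_numpy(x: torch.Tensor) -> np.ndarray:
    # Reusing `x` for the converted value contradicts its `Tensor`
    # annotation, so Pyre reports code 9 on the assignment.
    # pyre-fixme[9]: x has type `Tensor`; used as `Union[ndarray, Tensor]`.
    x = np.array(x) if isinstance(x, list) else x
    return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

def safe_lshift(value: torch.Tensor, bits: int, exp: Optional[int]) -> torch.Tensor:
    return value * (2 ** bits) if exp is None else value * (2.0 ** (bits - exp))

out = torch.ones(4)
private_exp: Optional[torch.Tensor] = None
# `private_exp` is Optional[Tensor] but the parameter wants Optional[int];
# the fixme records the mismatch without changing runtime behavior.
# pyre-fixme[6]: For 3rd argument expected `Optional[int]` but got `Optional[Tensor]`.
out = safe_lshift(out, 2, private_exp)
print(to_numpy(out))  # array([4., 4., 4., 4.], dtype=float32)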

fbgemm_gpu/test/quantize/mx4_test.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@
 from hypothesis import given, settings, Verbosity

 from . import common  # noqa E402
+
+# pyre-fixme[21]: Could not find name `open_source` in
+#  `deeplearning.fbgemm.fbgemm_gpu.test.quantize.common`.
 from .common import open_source
 from .mx.common import (
     _get_format_params,
@@ -133,6 +136,8 @@ def fake_quantize_mx(

     # Undo tile reshaping
     if group_size:
+        # pyre-fixme[61]: `padded_shape` is undefined, or not always defined.
+        # pyre-fixme[61]: `orig_shape` is undefined, or not always defined.
         A = _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes)

     return A
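
Code 61 is Pyre's possibly-undefined check: `padded_shape` and `orig_shape` are presumably bound inside an earlier `if group_size:` block, and the checker does not prove that two identical guards always agree. A minimal hypothetical sketch of the shape round-trip:

import torch

def fake_quantize(A: torch.Tensor, group_size: int) -> torch.Tensor:
    if group_size:
        orig_shape = A.shape
        A = A.reshape(-1, group_size)  # tile the input into groups

    A = A.clone()  # stand-in for the actual quantize/dequantize work

    # Undo tile reshaping
    if group_size:
        # Pyre cannot link this guard to the one above, so the name is
        # flagged as possibly unbound on this path.
        # pyre-fixme[61]: `orig_shape` is undefined, or not always defined.
        A = A.reshape(orig_shape)
    return A

print(fake_quantize(torch.arange(8.0), 4).shape)  # torch.Size([8])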

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 10 additions & 1 deletion
@@ -146,6 +146,7 @@ def generate_ssd_tbes(
     ssd_shards: int = 1,  # from SSDTableBatchedEmbeddingBags
     optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD,
     cache_set_scale: float = 1.0,
+    # pyre-fixme[9]: pooling_mode has type `bool`; used as `PoolingMode`.
     pooling_mode: bool = PoolingMode.SUM,
     weights_precision: SparseType = SparseType.FP32,
     output_dtype: SparseType = SparseType.FP32,
@@ -239,6 +240,10 @@ def generate_ssd_tbes(
     if weights_precision == SparseType.FP16:
         emb_ref = [emb.float() for emb in emb_ref]

+    # pyre-fixme[7]: Expected `Tuple[SSDTableBatchedEmbeddingBags,
+    #  List[EmbeddingBag]]` but got `Tuple[SSDTableBatchedEmbeddingBags,
+    #  Union[List[Union[Embedding, EmbeddingBag]], List[Embedding],
+    #  List[EmbeddingBag]]]`.
     return emb, emb_ref

 def concat_ref_tensors(
@@ -724,7 +729,11 @@ def test_ssd_cache(

         # pyre-fixme[16]: Optional type has no attribute `float`.
         optim_state_r.add_(
-            emb_r.weight.grad.float().to_dense().pow(2).mean(dim=1)
+            # pyre-fixme[16]: `Optional` has no attribute `float`.
+            emb_r.weight.grad.float()
+            .to_dense()
+            .pow(2)
+            .mean(dim=1)
         )
         torch.testing.assert_close(
             optim_state_t.float(),
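
The last file collects two more codes, and its third hunk also reflows a chained call so that the code-16 suppression (`Optional` has no attribute `float`) sits directly above the `.float()` access it targets. Code 9 applies to parameter defaults as well, as in `pooling_mode: bool = PoolingMode.SUM`, where the default's enum type contradicts the `bool` annotation; code 7 flags a return value whose inferred type is wider than the annotated return type. A hypothetical sketch of both (all names illustrative):

from enum import Enum
from typing import List, Tuple, Union

class PoolingMode(Enum):
    SUM = 0
    MEAN = 1

class Emb: ...
class RefEmb: ...

def generate_tbes(
    count: int,
    # The default is an enum member, not a bool; code 9 records the clash
    # between the annotation and the default's actual type.
    # pyre-fixme[9]: pooling_mode has type `bool`; used as `PoolingMode`.
    pooling_mode: bool = PoolingMode.SUM,
) -> Tuple[Emb, List[RefEmb]]:
    refs: Union[List[Emb], List[RefEmb]] = [RefEmb() for _ in range(count)]
    # The inferred type of `refs` is a union wider than `List[RefEmb]`.
    # pyre-fixme[7]: Expected `Tuple[Emb, List[RefEmb]]` but got
    #  `Tuple[Emb, Union[List[Emb], List[RefEmb]]]`.
    return Emb(), refs

emb, refs = generate_tbes(2)
print(type(emb).__name__, len(refs))  # Emb 2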
