
Commit 1203558

jwfromm authored and facebook-github-bot committed
Fix zero_start_index_M argument for triton rowwise quantize (pytorch#3639)
Summary:
X-link: facebookresearch/FBGEMM#714

D68797978 implemented a new feature that allows partial rowwise quantization of jagged tensors, in the hope of improving MoE performance. However, it operated on the wrong dimension (oops). This update shifts the operation to the proper dimension: one nonzero-row count per group.

Reviewed By: jasonjk-park, jiawenliu64

Differential Revision: D68872138
1 parent 1c79ba7

File tree: 2 files changed (+26, −25 lines)
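
The gist of the fix, before the per-file diffs: the input is viewed as groups of [M, K] matrices, and `zero_start_index_M` must carry exactly one valid-row count per group, not one count per row along K as before. A minimal standalone sketch of the intended semantics (shapes and names here are illustrative, not from the commit):

```python
import torch

# Illustrative shapes: G groups, each an [M, K] matrix that may have
# all-zero trailing rows (jagged along M).
G, M, K = 4, 8, 16
x = torch.randn(G, M, K)

# Zero out the tail rows of each group to simulate a jagged input.
valid_rows = torch.randint(0, M + 1, (G,))
x[torch.arange(M).expand(G, M) >= valid_rows.unsqueeze(-1)] = 0.0

# One count per [M, K] group -- the shape the fixed kernel expects.
# (abs() guards against a nonzero row summing to exactly zero.)
zero_start_index_M = torch.count_nonzero(x.abs().sum(dim=-1), dim=-1)
assert zero_start_index_M.shape == (G,)
```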

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 12 additions & 5 deletions

```diff
@@ -60,16 +60,23 @@ def _test_quantize_fp8_row(
         # Apply sparsification if specified.
         zero_start_index_M = None
         if use_jagged:
+            # View input as [G, M, K] where G is the number of groups.
+            grouped_input = input_a.view(
+                -1, input_a.shape[-2], input_a.shape[-1]
+            )
             m_vals = torch.randint(
-                0, input_a.shape[-1] + 1, (input_a.shape[:-1])
+                0, grouped_input.shape[1] + 1, (grouped_input.shape[0],)
             )
-            mask = torch.arange(input_a.shape[-1]).expand(
-                input_a.shape[:-1] + (input_a.shape[-1],)
+            mask = torch.arange(grouped_input.shape[-2]).expand(
+                (grouped_input.shape[0], grouped_input.shape[1])
             ) >= m_vals.unsqueeze(-1)
             # Set corresponding values to 0.
-            input_a[mask] = 0.0
+            grouped_input[mask] = 0.0
             # Generate nonzero tensor in same layout as input.
-            zero_start_index_M = torch.count_nonzero(input_a, dim=-1)
+            zero_start_index_M = torch.count_nonzero(
+                torch.sum(grouped_input, dim=-1), dim=-1
+            )
+
         a_fp8, a_scale = quantize_fp8_row(
             input_a,
             scale_ub=scale_ub,
```
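
The updated test logic can be sanity-checked standalone; a sketch with made-up shapes (only `input_a` and the identifiers above come from the diff):

```python
import torch

# input_a as the test sees it: arbitrary leading dims, then [M, K].
input_a = torch.randn(2, 3, 8, 16)

# View input as [G, M, K] where G is the number of groups.
grouped_input = input_a.view(-1, input_a.shape[-2], input_a.shape[-1])
G, M, _ = grouped_input.shape

# Draw one valid-row count per group, then zero out each group's tail rows.
m_vals = torch.randint(0, M + 1, (G,))
mask = torch.arange(M).expand(G, M) >= m_vals.unsqueeze(-1)
grouped_input[mask] = 0.0  # the view shares storage, so input_a changes too

# One entry per [M, K] group, matching the fixed kernel's expectation.
zero_start_index_M = torch.count_nonzero(
    torch.sum(grouped_input, dim=-1), dim=-1
)
print(zero_start_index_M.shape)  # torch.Size([6])
```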

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 14 additions & 20 deletions
```diff
@@ -2316,7 +2316,6 @@ def _kernel_quantize_fp8_row(
     stride_ok,
     stride_zb,
     stride_zm,
-    stride_zn,
     TL_FP8_DTYPE: tl.constexpr,
     MAX_FP8: tl.constexpr,
     EPS: tl.constexpr,
```
```diff
@@ -2354,7 +2353,6 @@ def _kernel_quantize_fp8_row(
         stride_ok (int): Stride of k dimension of output.
         stride_zb (int): Stride of b dimension of jagged index.
         stride_zm (int): Stride of m dimension of jagged index.
-        stride_zn (int): Stride of n dimension of jagged index.
         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
         MAX_FP8 (float): Maximum expressible value for FP8.
         EPS (float): Epsilon value for numerical stability.
```
```diff
@@ -2380,24 +2378,22 @@ def _kernel_quantize_fp8_row(
         + (pid % (M * N)) % N * stride_on
     )
 
-    if JAGGED:
-        z_offset_base = (
-            pid // (M * N) * stride_zb
-            + (pid % (M * N)) // N * stride_zm
-            + (pid % (M * N)) % N * stride_zn
-        )
-        row_size = tl.load(zero_start_index_M + z_offset_base)
-    else:
-        row_size = K
+    K_in = K
 
-    blocks = tl.cdiv(row_size, BLOCK_SIZE)
+    if JAGGED:
+        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
+        group_rows = tl.load(zero_start_index_M + z_offset_base)
+        current_row = pid % N
+        # If this row is empty, don't process any of it.
+        if current_row >= group_rows:
+            K_in = 0
 
     # Calculate max.
     cur_max = 0.0
-    for _k in range(0, blocks):
+    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < row_size,
+            mask=n_offset < K_in,
             other=0.0,
         )
         tile_max = tl.max(tl.abs(a))
```
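
For readers tracing the kernel: each Triton program handles one K-length row, with `pid` linearized over [B, M, N]. The fix drops the `stride_zn` term from the jagged lookup, so `zero_start_index_M` is indexed per (batch, group) and each row compares its own n-index against the group's valid-row count. The same index arithmetic in plain Python (grid sizes are illustrative):

```python
# Decompose a linear program id over a [B, M, N] grid of K-length rows.
B, M, N = 2, 3, 4
for pid in range(B * M * N):
    b = pid // (M * N)        # batch index (stride_zb term)
    m = (pid % (M * N)) // N  # group within the batch (stride_zm term)
    n = (pid % (M * N)) % N   # row within the group
    # The fixed kernel loads the per-(b, m) count group_rows and sets
    # K_in = 0 when n >= group_rows, so empty rows are skipped entirely.
```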
```diff
@@ -2418,15 +2414,14 @@ def _kernel_quantize_fp8_row(
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < row_size,
+            mask=n_offset < K_in,
             other=0.0,
         )
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
         # This is required for AMD. Nvidia's default saturation
         # handles it, but it's nice to have anyway.
-        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
-        a_fp8.to(TL_FP8_DTYPE)
+        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
         tl.store(
             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
             a_fp8,
```
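
The other change in this hunk fixes a silent no-op: `.to()` returns a converted value rather than converting in place, so the old `a_fp8.to(TL_FP8_DTYPE)` threw its result away and the stored tensor kept the wider dtype. The same pitfall in plain PyTorch:

```python
import torch

t = torch.randn(4)
t.to(torch.float16)      # no-op here: the converted tensor is discarded
t = t.to(torch.float16)  # correct: rebind to the returned tensor
assert t.dtype == torch.float16
```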
```diff
@@ -2481,7 +2476,6 @@ def triton_quantize_fp8_row(
         a_fp8.stride(3),
         zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
         zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
-        zero_start_index_M.stride(2) if zero_start_index_M is not None else None,
         TL_FP8_DTYPE=tl_dtype,
         MAX_FP8=max_fp8,
         EPS=eps,
```
```diff
@@ -2527,8 +2521,8 @@ def quantize_fp8_row(
         while a.dim() < 4:
             a = a.unsqueeze(0)
         if zero_start_index_M is not None:
-            while zero_start_index_M.dim() < 3:
-                zero_start_index_M = zero_start_index_M.unsqueeze(0)
+            # There should be one value of zero_start_index_M per NxK matrix.
+            zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
         a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
         return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
     # else use pytorch implementation.
```
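
Putting it together, a caller-side sketch of the corrected contract. The import path follows the file shown above and the keyword arguments follow the diff, but both are assumptions here, as is the availability of a CUDA device with FP8 support:

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# A batch of G groups, each an [M, K] matrix, jagged along M.
G, M, K = 4, 128, 256
a = torch.randn(G, M, K, device="cuda", dtype=torch.bfloat16)

# One valid-row count per [M, K] group (not one value per row along K).
zero_start_index_M = torch.randint(0, M + 1, (G,), device="cuda")

a_fp8, a_scale = quantize_fp8_row(a, zero_start_index_M=zero_start_index_M)
print(a_fp8.shape, a_scale.shape)  # [G, M, K] and [G, M]
```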
