@@ -2316,7 +2316,6 @@ def _kernel_quantize_fp8_row(
2316
2316
stride_ok ,
2317
2317
stride_zb ,
2318
2318
stride_zm ,
2319
- stride_zn ,
2320
2319
TL_FP8_DTYPE : tl .constexpr ,
2321
2320
MAX_FP8 : tl .constexpr ,
2322
2321
EPS : tl .constexpr ,
@@ -2354,7 +2353,6 @@ def _kernel_quantize_fp8_row(
2354
2353
stride_ok (int): Stride of k dimension of output.
2355
2354
stride_zb (int): Stride of b dimension of jagged index.
2356
2355
stride_zm (int): Stride of m dimension of jagged index.
2357
- stride_zn (int): Stride of n dimension of jagged index.
2358
2356
TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
2359
2357
MAX_FP8 (float): Maximum expressible value for FP8.
2360
2358
EPS (float): Epsilon value for numerical stability.
@@ -2380,24 +2378,22 @@ def _kernel_quantize_fp8_row(
2380
2378
+ (pid % (M * N )) % N * stride_on
2381
2379
)
2382
2380
2383
- if JAGGED :
2384
- z_offset_base = (
2385
- pid // (M * N ) * stride_zb
2386
- + (pid % (M * N )) // N * stride_zm
2387
- + (pid % (M * N )) % N * stride_zn
2388
- )
2389
- row_size = tl .load (zero_start_index_M + z_offset_base )
2390
- else :
2391
- row_size = K
2381
+ K_in = K
2392
2382
2393
- blocks = tl .cdiv (row_size , BLOCK_SIZE )
2383
+ if JAGGED :
2384
+ z_offset_base = pid // (M * N ) * stride_zb + (pid % (M * N )) // N * stride_zm
2385
+ group_rows = tl .load (zero_start_index_M + z_offset_base )
2386
+ current_row = pid % N
2387
+ # If this row is empty, don't process any of it.
2388
+ if current_row >= group_rows :
2389
+ K_in = 0
2394
2390
2395
2391
# Calculate max.
2396
2392
cur_max = 0.0
2397
- for _k in range (0 , blocks ):
2393
+ for _k in range (0 , tl . cdiv ( K_in , BLOCK_SIZE ) ):
2398
2394
a = tl .load (
2399
2395
A + a_offset_base + n_offset * stride_ak ,
2400
- mask = n_offset < row_size ,
2396
+ mask = n_offset < K_in ,
2401
2397
other = 0.0 ,
2402
2398
)
2403
2399
tile_max = tl .max (tl .abs (a ))
@@ -2418,15 +2414,14 @@ def _kernel_quantize_fp8_row(
2418
2414
for _k in range (0 , tl .cdiv (K , BLOCK_SIZE )):
2419
2415
a = tl .load (
2420
2416
A + a_offset_base + n_offset * stride_ak ,
2421
- mask = n_offset < row_size ,
2417
+ mask = n_offset < K_in ,
2422
2418
other = 0.0 ,
2423
2419
)
2424
2420
a_fp8 = a * a_scale
2425
2421
# Clamp A to fp8 range to make sure there's no overflow.
2426
2422
# This is required for AMD. Nvidia's default saturation
2427
2423
# handles it, but it's nice to have anyway.
2428
- a_fp8 = tl .clamp (a_fp8 , - MAX_FP8 , MAX_FP8 )
2429
- a_fp8 .to (TL_FP8_DTYPE )
2424
+ a_fp8 = tl .clamp (a_fp8 , - MAX_FP8 , MAX_FP8 ).to (TL_FP8_DTYPE )
2430
2425
tl .store (
2431
2426
A_fp8 + a_fp8_offset_base + n_offset * stride_ok ,
2432
2427
a_fp8 ,
@@ -2481,7 +2476,6 @@ def triton_quantize_fp8_row(
2481
2476
a_fp8 .stride (3 ),
2482
2477
zero_start_index_M .stride (0 ) if zero_start_index_M is not None else None ,
2483
2478
zero_start_index_M .stride (1 ) if zero_start_index_M is not None else None ,
2484
- zero_start_index_M .stride (2 ) if zero_start_index_M is not None else None ,
2485
2479
TL_FP8_DTYPE = tl_dtype ,
2486
2480
MAX_FP8 = max_fp8 ,
2487
2481
EPS = eps ,
@@ -2527,8 +2521,8 @@ def quantize_fp8_row(
2527
2521
while a .dim () < 4 :
2528
2522
a = a .unsqueeze (0 )
2529
2523
if zero_start_index_M is not None :
2530
- while zero_start_index_M . dim () < 3 :
2531
- zero_start_index_M = zero_start_index_M .unsqueeze ( 0 )
2524
+ # There should be one value of zero_start_index_M per NxK matrix.
2525
+ zero_start_index_M = zero_start_index_M .view ( a . shape [ 0 ], a . shape [ 1 ] )
2532
2526
a_fp8 , a_scale = triton_quantize_fp8_row (a , scale_ub , zero_start_index_M )
2533
2527
return a_fp8 .view (a_shape ), a_scale .view (a_shape [:- 1 ])
2534
2528
# else use pytorch implementation.
0 commit comments