Skip to content

Commit 7c5e9b7

Browse files
committed
better block_quant_dequant impl
1 parent abe5fa2 commit 7c5e9b7

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -240,22 +240,19 @@ def block_quant_dequant(
     assert n_tiles == x_s.shape[0]
     assert k_tiles == x_s.shape[1]
 
-    x_dq_block = x_q_block.to(dtype)
+    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
 
-    x_dq_block_tiles = [
-        [
-            x_dq_block[
+    for j in range(n_tiles):
+        for i in range(k_tiles):
+            x_q_block_tile = x_q_block[
                 j * block_n : min((j + 1) * block_n, n),
                 i * block_k : min((i + 1) * block_k, k),
             ]
-            for i in range(k_tiles)
-        ]
-        for j in range(n_tiles)
-    ]
-
-    for i in range(k_tiles):
-        for j in range(n_tiles):
-            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+            x_dq_block_tile = x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
 
     return x_dq_block
 
0 commit comments

Comments
 (0)