From 4e9dae6b37c19be45c9d3952737c3a4a0a37d3bc Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 6 Mar 2025 23:24:12 -0800 Subject: [PATCH] Cast loaded row/column indices to int32 in _sdd_kernel --- stk/backend/triton_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stk/backend/triton_kernels.py b/stk/backend/triton_kernels.py index c535309..e2cbd95 100644 --- a/stk/backend/triton_kernels.py +++ b/stk/backend/triton_kernels.py @@ -41,8 +41,8 @@ def _sdd_kernel(A, B, C, M, N, K, ): # matrix multiplication pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) + pid_m = tl.load(row_indices + pid).to(tl.int32) + pid_n = tl.load(column_indices + pid).to(tl.int32) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)