Compile with the row_sums tensor so the kernel writes per-row sums

This commit is contained in:
2026-05-27 07:10:00 +00:00
parent 0736a04d9b
commit 778d9d4f4f

View File

@@ -45,6 +45,7 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
@@ -54,8 +55,9 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
mRS = ct.from_dlpack(row_sums).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums))
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
_kernel_cache[key] = (compiled, kernel)
return (compiled, kernel)