Compile with row_sums tensor so kernel writes per-row row_sums
This commit is contained in:
@@ -45,6 +45,7 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
|
||||
v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
@@ -54,8 +55,9 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
|
||||
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
|
||||
mRS = ct.from_dlpack(row_sums).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums))
|
||||
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
|
||||
_kernel_cache[key] = (compiled, kernel)
|
||||
return (compiled, kernel)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user