diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 3f89f1a8..80b40aec 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -45,6 +45,7 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False, v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda') c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -54,8 +55,9 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False, mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse)) + mRS = ct.from_dlpack(row_sums).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums)) - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS) _kernel_cache[key] = (compiled, kernel) return (compiled, kernel)