From dbdbcecadc1f4b917fd75d433bf37b5ddf580d04 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 15:02:43 +0000 Subject: [PATCH] fix: sink_bias must be pre-converted to CuTe tensor before passing to compile --- dsv4/kernels/attention/fmha.py | 3 +-- tests/unit/test_d5c_fused.py | 10 ++++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 4d9d3d3c..0992413e 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -149,8 +149,7 @@ class FmhaKernel: # D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory. # Never actually read (const_expr(self.n_comp > 0) guards the read). sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,))) - else: - sink_bias = ct.from_dlpack(sink_bias).mark_layout_dynamic(leading_dim=ct.get_leading_dim(sink_bias)) + # else: sink_bias is already a CuTe tensor (caller must pass via ct.from_dlpack) # Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension # For single-head (n_h=1): grid=(1,1,1) — backward compatible self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) diff --git a/tests/unit/test_d5c_fused.py b/tests/unit/test_d5c_fused.py index e623bad0..5a12760b 100644 --- a/tests/unit/test_d5c_fused.py +++ b/tests/unit/test_d5c_fused.py @@ -175,16 +175,17 @@ def test_d5c_combined(): # Compile print('Compiling D5c kernel (combined KV + sink bias)...', flush=True) + mSinkBias = to_cute(attn_sink) compiled = cute.compile( kernel, mQ, mK, mV, mC, stream, mLSE, - swa_len=swa_len, sink_bias=attn_sink, + swa_len=swa_len, sink_bias=mSinkBias, ) # Run print('Running D5c kernel...', flush=True) compiled( mQ, mK, mV, mC, stream, mLSE, - swa_len=swa_len, sink_bias=attn_sink, + swa_len=swa_len, sink_bias=mSinkBias, ) torch.cuda.synchronize() @@ -260,13 +261,14 @@ def test_d5c_with_causal(): mLSE = to_cute(lse_out) print('Compiling D5c kernel (causal + sink bias)...', flush=True) + mSinkBias = to_cute(attn_sink) compiled = cute.compile( kernel, mQ, mK, mV, mC, stream, mLSE, - swa_len=swa_len, sink_bias=attn_sink, + swa_len=swa_len, sink_bias=mSinkBias, ) compiled( mQ, mK, mV, mC, stream, mLSE, - swa_len=swa_len, sink_bias=attn_sink, + swa_len=swa_len, sink_bias=mSinkBias, ) torch.cuda.synchronize()