From d6a56342cc7dc8e5f2255dd5a5abba9f1e1b6fe2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 25 May 2026 17:31:01 +0000 Subject: [PATCH] D3: add swa_lens parameter to FmhaKernel (in-kernel masking TBD) --- dsv4/kernels/attention/fmha.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 0d7d7877..79604bc1 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -131,6 +131,9 @@ class FmhaKernel: # CuTeDSL doesn't support None parameters in @cute.kernel. if const_expr(lse is None): lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) + if const_expr(swa_lens is None): + # No SWA masking — pass a dummy tensor + swa_lens = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) # Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension # For single-head (n_h=1): grid=(1,1,1) — backward compatible self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_lens).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)