From df8442041482bb6bf5a0cd521d4267b3f91cace7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 10:53:14 +0000 Subject: [PATCH] fix: add is_causal to FmhaKernel __init__ signature --- dsv4/kernels/attention/fmha.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index c66d4598..d73e1df3 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -16,7 +16,7 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False): self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 @@ -32,7 +32,7 @@ class FmhaKernel: self.batch_size = batch_size self.normalize = normalize # D5a: False = emit un-normalized O + lse self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens - self.is_causal = False # D4: causal mask (k_coord > m_coord) on SWA branch + self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1