diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index d31e0f2d..67f65c77 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -15,13 +15,14 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True): self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) + self.normalize = normalize # D5a: False = emit un-normalized O + lse self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1