From ca53bdb8e1a0fd60724addcaa4a9a4f6d0099b5a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 03:54:03 +0000 Subject: [PATCH] perf: skip MQA GQA expansion in FMHA (stride=0, no 128x K/V copy) --- dsv4/kernels/attention/fmha_multitile_op.py | 33 +++++++++++++-------- dsv4/kernels/attention/production.py | 3 +- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/dsv4/kernels/attention/fmha_multitile_op.py b/dsv4/kernels/attention/fmha_multitile_op.py index 3ea6c463..71d29287 100644 --- a/dsv4/kernels/attention/fmha_multitile_op.py +++ b/dsv4/kernels/attention/fmha_multitile_op.py @@ -74,13 +74,14 @@ def _ensure_built(): def fmha_multitile_decode_raw( q: torch.Tensor, # (batch, n_h, T, hd) BF16 - k: torch.Tensor, # (batch, n_h, N, hd) BF16 - v: torch.Tensor, # (batch, n_h, hd, N) BF16 + k: torch.Tensor, # (batch, n_kv, N, hd) BF16 + v: torch.Tensor, # (batch, n_kv, hd, N) BF16 scale: float, n_comp: int = 0, swa_len: int = 0, is_causal: bool = False, attn_sink: Optional[torch.Tensor] = None, + skip_gqa_expand: bool = False, # Skip K/V repeat_interleave for MQA ) -> tuple[torch.Tensor, torch.Tensor]: """Launch the multi-tile TMA FMHA kernel. Returns (O, LSE).""" lib = _ensure_built() @@ -96,14 +97,18 @@ def fmha_multitile_decode_raw( q_per_kv = n_h // n_kv # GQA: expand K/V to n_h heads + # MQA fast path: skip the expensive repeat_interleave (128× memory copy). + # Instead, pass stride=0 for the head dimension so all Q heads read the same KV. + # This saves ~1.15MB allocation + copy per layer per decode step. if n_kv < n_h: - k = k.repeat_interleave(q_per_kv, dim=1) - v = v.repeat_interleave(q_per_kv, dim=1) + if skip_gqa_expand: + # Don't expand K/V — pass stride(1)=0 to kernel for MQA + pass + else: + k = k.repeat_interleave(q_per_kv, dim=1) + v = v.repeat_interleave(q_per_kv, dim=1) # Pad N to multiple of 128 (TMA descriptor alignment) - # CRITICAL: We track the ORIGINAL N (N_orig) separately from N_padded. - # The kernel uses s_k=N_orig as the logical KV length for softmax masking. - # Only the K/V tensors are padded (with zeros) for TMA alignment. N_orig = N N_padded = ((N + 127) // 128) * 128 if N < N_padded: @@ -128,6 +133,13 @@ def fmha_multitile_decode_raw( assert sb.shape == (B, n_h), f"sink_bias shape {sb.shape} != ({B}, {n_h})" sink_bias_ptr = ctypes.c_void_p(sb.data_ptr()) + # For MQA skip_gqa_expand: pass stride(1)=0 for K and V so all heads + # read from the same KV head (head 0). The kernel's CTA for head h + # computes k_ptr + h * k_stride1, so stride1=0 means all heads share + # the same K/V data without the 128× memory expansion. + k_stride1 = 0 if (n_kv < n_h and skip_gqa_expand) else k.stride(1) + v_stride1 = 0 if (n_kv < n_h and skip_gqa_expand) else v.stride(1) + ret = lib.fmha_multitile_decode_launch( ctypes.c_void_p(q.data_ptr()), ctypes.c_void_p(k.data_ptr()), @@ -140,15 +152,12 @@ def fmha_multitile_decode_raw( ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors) ctypes.c_int(hd), ctypes.c_int(q.stride(1)), ctypes.c_int(q.stride(0)), - ctypes.c_int(k.stride(1)), ctypes.c_int(k.stride(0)), - ctypes.c_int(v.stride(1)), ctypes.c_int(v.stride(0)), + ctypes.c_int(k_stride1), ctypes.c_int(k.stride(0)), + ctypes.c_int(v_stride1), ctypes.c_int(v.stride(0)), ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)), ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)), ctypes.c_float(scale), ) if ret != 0: raise RuntimeError(f"Multi-tile kernel launch failed: return code {ret}") - # E4: Removed torch.cuda.synchronize() — the C API launch returns an error - # code from the kernel setup. Async kernel errors will surface on the next - # CUDA API call. A full device sync is not needed on the hot path. return o, lse diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 1a1a8ab7..ba27e1b0 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -41,7 +41,8 @@ def _dsv4_attention_multitile( k_4d = k.unsqueeze(0).contiguous() v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous() - o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias) + o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias, + skip_gqa_expand=True) return o_4d.squeeze(0)