perf: skip GQA K/V expansion for MQA in FMHA (head stride=0, no 128x K/V copy)

This commit is contained in:
2026-06-02 03:54:03 +00:00
parent 7b82d31330
commit ca53bdb8e1
2 changed files with 23 additions and 13 deletions

View File

@@ -74,13 +74,14 @@ def _ensure_built():
def fmha_multitile_decode_raw(
q: torch.Tensor, # (batch, n_h, T, hd) BF16
k: torch.Tensor, # (batch, n_h, N, hd) BF16
v: torch.Tensor, # (batch, n_h, hd, N) BF16
k: torch.Tensor, # (batch, n_kv, N, hd) BF16
v: torch.Tensor, # (batch, n_kv, hd, N) BF16
scale: float,
n_comp: int = 0,
swa_len: int = 0,
is_causal: bool = False,
attn_sink: Optional[torch.Tensor] = None,
skip_gqa_expand: bool = False, # Skip K/V repeat_interleave for MQA
) -> tuple[torch.Tensor, torch.Tensor]:
"""Launch the multi-tile TMA FMHA kernel. Returns (O, LSE)."""
lib = _ensure_built()
@@ -96,14 +97,18 @@ def fmha_multitile_decode_raw(
q_per_kv = n_h // n_kv
# GQA: expand K/V to n_h heads
# MQA fast path: skip the expensive repeat_interleave (128× memory copy).
# Instead, pass stride=0 for the head dimension so all Q heads read the same KV.
# This saves ~1.15MB allocation + copy per layer per decode step.
if n_kv < n_h:
k = k.repeat_interleave(q_per_kv, dim=1)
v = v.repeat_interleave(q_per_kv, dim=1)
if skip_gqa_expand:
# Don't expand K/V — pass stride(1)=0 to kernel for MQA
pass
else:
k = k.repeat_interleave(q_per_kv, dim=1)
v = v.repeat_interleave(q_per_kv, dim=1)
# Pad N to multiple of 128 (TMA descriptor alignment)
# CRITICAL: We track the ORIGINAL N (N_orig) separately from N_padded.
# The kernel uses s_k=N_orig as the logical KV length for softmax masking.
# Only the K/V tensors are padded (with zeros) for TMA alignment.
N_orig = N
N_padded = ((N + 127) // 128) * 128
if N < N_padded:
@@ -128,6 +133,13 @@ def fmha_multitile_decode_raw(
assert sb.shape == (B, n_h), f"sink_bias shape {sb.shape} != ({B}, {n_h})"
sink_bias_ptr = ctypes.c_void_p(sb.data_ptr())
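# MQA fast path (skip_gqa_expand=True): pass a head stride (stride(1)) of 0
# for K and V so every Q head reads the same KV head (head 0). The kernel's
# CTA for head h computes k_ptr + h * k_stride1, so stride1=0 means all heads
# share the same K/V data without the 128× memory expansion.
# NOTE(review): the condition below tests n_kv < n_h, not n_kv == 1, so with
# skip_gqa_expand=True and n_kv > 1 (true GQA) every Q head would still read
# KV head 0 — confirm callers only set the flag for MQA.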
k_stride1 = 0 if (n_kv < n_h and skip_gqa_expand) else k.stride(1)
v_stride1 = 0 if (n_kv < n_h and skip_gqa_expand) else v.stride(1)
ret = lib.fmha_multitile_decode_launch(
ctypes.c_void_p(q.data_ptr()),
ctypes.c_void_p(k.data_ptr()),
@@ -140,15 +152,12 @@ def fmha_multitile_decode_raw(
ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors)
ctypes.c_int(hd),
ctypes.c_int(q.stride(1)), ctypes.c_int(q.stride(0)),
ctypes.c_int(k.stride(1)), ctypes.c_int(k.stride(0)),
ctypes.c_int(v.stride(1)), ctypes.c_int(v.stride(0)),
ctypes.c_int(k_stride1), ctypes.c_int(k.stride(0)),
ctypes.c_int(v_stride1), ctypes.c_int(v.stride(0)),
ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)),
ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)),
ctypes.c_float(scale),
)
if ret != 0:
raise RuntimeError(f"Multi-tile kernel launch failed: return code {ret}")
# E4: Removed torch.cuda.synchronize() — the C API launch returns an error
# code from the kernel setup. Async kernel errors will surface on the next
# CUDA API call. A full device sync is not needed on the hot path.
return o, lse

View File

@@ -41,7 +41,8 @@ def _dsv4_attention_multitile(
k_4d = k.unsqueeze(0).contiguous()
v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous()
o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias)
o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias,
skip_gqa_expand=True)
return o_4d.squeeze(0)