perf: skip GQA K/V expansion for MQA in FMHA (head stride=0, no 128x K/V copy)
This commit is contained in:
@@ -74,13 +74,14 @@ def _ensure_built():
|
||||
|
||||
def fmha_multitile_decode_raw(
|
||||
q: torch.Tensor, # (batch, n_h, T, hd) BF16
|
||||
k: torch.Tensor, # (batch, n_h, N, hd) BF16
|
||||
v: torch.Tensor, # (batch, n_h, hd, N) BF16
|
||||
k: torch.Tensor, # (batch, n_kv, N, hd) BF16
|
||||
v: torch.Tensor, # (batch, n_kv, hd, N) BF16
|
||||
scale: float,
|
||||
n_comp: int = 0,
|
||||
swa_len: int = 0,
|
||||
is_causal: bool = False,
|
||||
attn_sink: Optional[torch.Tensor] = None,
|
||||
skip_gqa_expand: bool = False, # Skip K/V repeat_interleave for MQA
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Launch the multi-tile TMA FMHA kernel. Returns (O, LSE)."""
|
||||
lib = _ensure_built()
|
||||
@@ -96,14 +97,18 @@ def fmha_multitile_decode_raw(
|
||||
q_per_kv = n_h // n_kv
|
||||
|
||||
# GQA: expand K/V to n_h heads
|
||||
# MQA fast path: skip the expensive repeat_interleave (128× memory copy).
|
||||
# Instead, pass stride=0 for the head dimension so all Q heads read the same KV.
|
||||
# This saves ~1.15MB allocation + copy per layer per decode step.
|
||||
if n_kv < n_h:
|
||||
k = k.repeat_interleave(q_per_kv, dim=1)
|
||||
v = v.repeat_interleave(q_per_kv, dim=1)
|
||||
if skip_gqa_expand:
|
||||
# Don't expand K/V — pass stride(1)=0 to kernel for MQA
|
||||
pass
|
||||
else:
|
||||
k = k.repeat_interleave(q_per_kv, dim=1)
|
||||
v = v.repeat_interleave(q_per_kv, dim=1)
|
||||
|
||||
# Pad N to multiple of 128 (TMA descriptor alignment)
|
||||
# CRITICAL: We track the ORIGINAL N (N_orig) separately from N_padded.
|
||||
# The kernel uses s_k=N_orig as the logical KV length for softmax masking.
|
||||
# Only the K/V tensors are padded (with zeros) for TMA alignment.
|
||||
N_orig = N
|
||||
N_padded = ((N + 127) // 128) * 128
|
||||
if N < N_padded:
|
||||
@@ -128,6 +133,13 @@ def fmha_multitile_decode_raw(
|
||||
assert sb.shape == (B, n_h), f"sink_bias shape {sb.shape} != ({B}, {n_h})"
|
||||
sink_bias_ptr = ctypes.c_void_p(sb.data_ptr())
|
||||
|
||||
# For MQA skip_gqa_expand: pass stride(1)=0 for K and V so all heads
|
||||
# read from the same KV head (head 0). The kernel's CTA for head h
|
||||
# computes k_ptr + h * k_stride1, so stride1=0 means all heads share
|
||||
# the same K/V data without the 128× memory expansion (valid only when n_kv == 1).
|
||||
k_stride1 = 0 if (n_kv < n_h and skip_gqa_expand) else k.stride(1)
|
||||
v_stride1 = 0 if (n_kv < n_h and skip_gqa_expand) else v.stride(1)
|
||||
|
||||
ret = lib.fmha_multitile_decode_launch(
|
||||
ctypes.c_void_p(q.data_ptr()),
|
||||
ctypes.c_void_p(k.data_ptr()),
|
||||
@@ -140,15 +152,12 @@ def fmha_multitile_decode_raw(
|
||||
ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors)
|
||||
ctypes.c_int(hd),
|
||||
ctypes.c_int(q.stride(1)), ctypes.c_int(q.stride(0)),
|
||||
ctypes.c_int(k.stride(1)), ctypes.c_int(k.stride(0)),
|
||||
ctypes.c_int(v.stride(1)), ctypes.c_int(v.stride(0)),
|
||||
ctypes.c_int(k_stride1), ctypes.c_int(k.stride(0)),
|
||||
ctypes.c_int(v_stride1), ctypes.c_int(v.stride(0)),
|
||||
ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)),
|
||||
ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)),
|
||||
ctypes.c_float(scale),
|
||||
)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Multi-tile kernel launch failed: return code {ret}")
|
||||
# E4: Removed torch.cuda.synchronize() — the C API launch returns an error
|
||||
# code from the kernel setup. Async kernel errors will surface on the next
|
||||
# CUDA API call. A full device sync is not needed on the hot path.
|
||||
return o, lse
|
||||
|
||||
@@ -41,7 +41,8 @@ def _dsv4_attention_multitile(
|
||||
k_4d = k.unsqueeze(0).contiguous()
|
||||
v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous()
|
||||
|
||||
o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias)
|
||||
o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias,
|
||||
skip_gqa_expand=True)
|
||||
return o_4d.squeeze(0)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user