Triton MLA perf fixes (#33529)

Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Koushik Dutta
2026-04-02 06:40:01 -07:00
committed by GitHub
parent 16a65e4173
commit d9408ffba3
2 changed files with 69 additions and 26 deletions

View File

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata,
)
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import triton
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionLayer,
@@ -115,6 +116,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
if is_quantized_kv_cache(self.kv_cache_dtype):
self.supports_quant_query_input = False
self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
@@ -149,7 +152,24 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
if envs.VLLM_BATCH_INVARIANT:
num_kv_splits = 1
else:
# Minimum work per split
# hardware dependent
min_work_per_split = 512
ideal_splits = max(1, attn_metadata.max_seq_len // min_work_per_split)
# use power of 2 to avoid excessive kernel instantiations
ideal_splits = triton.next_power_of_2(ideal_splits)
# Calculate SM-based maximum splits with occupancy multiplier
# 2-4x allows multiple blocks per SM for latency hiding
# hardware dependent
occupancy_multiplier = 2
max_splits = self._sm_count * occupancy_multiplier
num_kv_splits = min(ideal_splits, max_splits)
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
@@ -186,6 +206,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
PAGE_SIZE,
k_scale=layer._k_scale,
v_scale=layer._k_scale,
is_mla=True,
)
return o, lse

View File

@@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
IS_MLA: tl.constexpr = False,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
@@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1(
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
q = tl.load(
Q + offs_q,
mask=(mask_h[:, None]) & (mask_d[None, :]),
other=0.0,
cache_modifier=".ca",
)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1(
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
Q + off_qpe,
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
other=0.0,
cache_modifier=".ca",
)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
@@ -331,9 +340,14 @@ def _fwd_grouped_kernel_stage1(
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
base_offs_k = cur_kv_head * stride_buf_kh + offs_d[:, None]
base_offs_v = cur_kv_head * stride_buf_vh + offs_dv[None, :]
if BLOCK_DPE > 0:
base_offs_kpe = cur_kv_head * stride_buf_kh + offs_dpe[:, None]
ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
for start_n in tl.range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens
@@ -341,31 +355,29 @@ def _fwd_grouped_kernel_stage1(
+ offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
cache_modifier=".ca",
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
# explicitly facilitate overlapping load/compute
offs_buf_k = kv_loc[None, :] * stride_buf_kbs + base_offs_k
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
cache_modifier=".cg",
)
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
offs_buf_kpe = kv_loc[None, :] * stride_buf_kbs + base_offs_kpe
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
cache_modifier=".cg",
)
if kpe.dtype.is_fp8():
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
@@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)
offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
if not IS_MLA:
offs_buf_v = kv_loc[:, None] * stride_buf_vbs + base_offs_v
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
else:
# MLA uses a single c_kv.
# loading the same c_kv to interpret it as v is not necessary.
# transpose the existing c_kv (aka k) for the dot product.
v = tl.trans(k)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd(
logit_cap,
k_scale,
v_scale,
is_mla=False,
):
# with is_mla there is only a single c_kv in smem.
# could increase BLOCK or num_stages.
BLOCK = 32
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
@@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd(
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
IS_MLA=is_mla,
**extra_kargs,
)
@@ -673,6 +691,7 @@ def decode_attention_fwd_grouped(
logit_cap=0.0,
k_scale=None,
v_scale=None,
is_mla=False,
):
_decode_grouped_att_m_fwd(
q,
@@ -687,6 +706,7 @@ def decode_attention_fwd_grouped(
logit_cap,
k_scale,
v_scale,
is_mla=is_mla,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -708,6 +728,7 @@ def decode_attention_fwd(
logit_cap=0.0,
k_scale=None,
v_scale=None,
is_mla=False,
):
assert num_kv_splits == attn_logits.shape[2]
@@ -753,4 +774,5 @@ def decode_attention_fwd(
logit_cap,
k_scale,
v_scale,
is_mla=is_mla,
)