Triton MLA perf fixes (#33529)
Signed-off-by: Koushik Dutta <koushd@gmail.com> Co-authored-by: root <root@ubuntu-nvidia.localdomain> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonMetadata,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
@@ -115,6 +116,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
self.supports_quant_query_input = False
|
||||
|
||||
self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
@@ -149,7 +152,24 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
|
||||
|
||||
# For batch invariance, use only 1 split to ensure deterministic reduction
|
||||
num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
num_kv_splits = 1
|
||||
else:
|
||||
# Minimum work per split
|
||||
# hardware dependent
|
||||
min_work_per_split = 512
|
||||
|
||||
ideal_splits = max(1, attn_metadata.max_seq_len // min_work_per_split)
|
||||
|
||||
# use power of 2 to avoid excessive kernel instantiations
|
||||
ideal_splits = triton.next_power_of_2(ideal_splits)
|
||||
|
||||
# Calculate SM-based maximum splits with occupancy multiplier
|
||||
# 2-4x allows multiple blocks per SM for latency hiding
|
||||
# hardware dependent
|
||||
occupancy_multiplier = 2
|
||||
max_splits = self._sm_count * occupancy_multiplier
|
||||
num_kv_splits = min(ideal_splits, max_splits)
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
@@ -186,6 +206,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
PAGE_SIZE,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._k_scale,
|
||||
is_mla=True,
|
||||
)
|
||||
|
||||
return o, lse
|
||||
|
||||
@@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1(
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
IS_MLA: tl.constexpr = False,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head_id = tl.program_id(1)
|
||||
@@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1(
|
||||
cur_batch_req_idx = cur_batch
|
||||
|
||||
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
||||
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
|
||||
q = tl.load(
|
||||
Q + offs_q,
|
||||
mask=(mask_h[:, None]) & (mask_d[None, :]),
|
||||
other=0.0,
|
||||
cache_modifier=".ca",
|
||||
)
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
@@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1(
|
||||
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
|
||||
)
|
||||
qpe = tl.load(
|
||||
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
|
||||
Q + off_qpe,
|
||||
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
|
||||
other=0.0,
|
||||
cache_modifier=".ca",
|
||||
)
|
||||
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
@@ -331,9 +340,14 @@ def _fwd_grouped_kernel_stage1(
|
||||
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
base_offs_k = cur_kv_head * stride_buf_kh + offs_d[:, None]
|
||||
base_offs_v = cur_kv_head * stride_buf_vh + offs_dv[None, :]
|
||||
if BLOCK_DPE > 0:
|
||||
base_offs_kpe = cur_kv_head * stride_buf_kh + offs_dpe[:, None]
|
||||
|
||||
ks = tl.load(k_scale)
|
||||
vs = tl.load(v_scale)
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
for start_n in tl.range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
Req_to_tokens
|
||||
@@ -341,31 +355,29 @@ def _fwd_grouped_kernel_stage1(
|
||||
+ offs_n // PAGE_SIZE,
|
||||
mask=offs_n < split_kv_end,
|
||||
other=0,
|
||||
cache_modifier=".ca",
|
||||
)
|
||||
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
|
||||
offs_buf_k = (
|
||||
kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
|
||||
# explicitly facilitate overlapping load/compute
|
||||
offs_buf_k = kv_loc[None, :] * stride_buf_kbs + base_offs_k
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
|
||||
other=0.0,
|
||||
cache_modifier=".cg",
|
||||
)
|
||||
|
||||
if k.dtype.is_fp8():
|
||||
k = (k.to(tl.float32) * ks).to(q.dtype)
|
||||
qk = tl.dot(q, k.to(q.dtype))
|
||||
if BLOCK_DPE > 0:
|
||||
offs_buf_kpe = (
|
||||
kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
offs_buf_kpe = kv_loc[None, :] * stride_buf_kbs + base_offs_kpe
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_buf_kpe,
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
|
||||
other=0.0,
|
||||
cache_modifier=".cg",
|
||||
)
|
||||
if kpe.dtype.is_fp8():
|
||||
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
|
||||
@@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1(
|
||||
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
|
||||
)
|
||||
|
||||
offs_buf_v = (
|
||||
kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
if v.dtype.is_fp8():
|
||||
v = (v.to(tl.float32) * vs).to(q.dtype)
|
||||
if not IS_MLA:
|
||||
offs_buf_v = kv_loc[:, None] * stride_buf_vbs + base_offs_v
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
if v.dtype.is_fp8():
|
||||
v = (v.to(tl.float32) * vs).to(q.dtype)
|
||||
else:
|
||||
# MLA uses a single c_kv.
|
||||
# loading the same c_kv to interpret it as v is not necessary.
|
||||
# transpose the existing c_kv (aka k) for the dot product.
|
||||
v = tl.trans(k)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
@@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd(
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
is_mla=False,
|
||||
):
|
||||
# with is_mla there is only a single c_kv in smem.
|
||||
# could increase BLOCK or num_stages.
|
||||
BLOCK = 32
|
||||
Lk = k_buffer.shape[-1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
@@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd(
|
||||
num_stages=num_stages,
|
||||
Lk=Lk,
|
||||
Lv=Lv,
|
||||
IS_MLA=is_mla,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
@@ -673,6 +691,7 @@ def decode_attention_fwd_grouped(
|
||||
logit_cap=0.0,
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
is_mla=False,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
@@ -687,6 +706,7 @@ def decode_attention_fwd_grouped(
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
is_mla=is_mla,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
|
||||
@@ -708,6 +728,7 @@ def decode_attention_fwd(
|
||||
logit_cap=0.0,
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
is_mla=False,
|
||||
):
|
||||
assert num_kv_splits == attn_logits.shape[2]
|
||||
|
||||
@@ -753,4 +774,5 @@ def decode_attention_fwd(
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
is_mla=is_mla,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user