Update FlashMLA (#32491)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-01-20 22:03:37 -07:00
committed by GitHub
parent 7ab80a8e37
commit b4f64e5b02
4 changed files with 169 additions and 42 deletions

View File

@@ -17,7 +17,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -397,6 +396,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
# FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
self.fp8_decode_padded_heads = (
FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
@@ -417,14 +420,20 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
(max_num_seqs, 1), dtype=torch.int32, device=self.device
)
# Equation taken from FlashMLA/csrc/pybind.cpp
h_q, h_k = self.num_heads, 1
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
max_num_sm_parts = int(
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
)
# Equation taken from FlashMLA/csrc/api/sparse_decode.h
# For sparse FP8 decode, the formula depends on architecture:
# - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
# - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
# - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
# For max buffer size, use s_q = 1 (the case that produces largest output)
# Use padded head count since that's what will be passed to the kernel
h_q = self.fp8_decode_padded_heads
if current_platform.is_device_capability_family(100):
max_num_sm_parts *= 2
# SM100 head64 or head64x2 uses full SM count
max_num_sm_parts = sm_count
else:
# SM90 uses h_q/64 divisor
max_num_sm_parts = sm_count // max(1, h_q // 64)
self.tile_scheduler_metadata_buffer = torch.empty(
# TileSchedulerMetaDataSize = 8
# see: FlashMLA/csrc/params.h
@@ -455,12 +464,15 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
"""
num_tokens = common_attn_metadata.num_actual_tokens
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
# Build metadata for all tokens as a single batch
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
num_q_tokens_per_head_k=num_tokens * self.num_heads,
num_q_tokens_per_head_k=num_tokens * padded_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_q=padded_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
@@ -606,11 +618,13 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * self.num_heads,
num_q_tokens_per_head_k=decode_query_len * padded_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_q=padded_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
@@ -689,6 +703,12 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
@staticmethod
def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
# FP8 decode kernel only supports h_q = 64 or 128
# Compute padded head count for decode
return 64 if num_heads <= 64 else 128
def __init__(
self,
num_heads: int,
@@ -722,7 +742,11 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability_family(100) else 64
# Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
self.prefill_padding = (
128 if current_platform.is_device_capability_family(100) else 64
)
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
@@ -903,8 +927,22 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> torch.Tensor:
return flash_mla_with_kvcache(
) -> tuple[torch.Tensor, torch.Tensor]:
# q shape: (batch, seq_len, num_heads, head_dim)
actual_num_heads = q.size(2)
padded_num_heads = self.fp8_decode_padded_heads
# Pad query if needed (kernel only supports h_q = 64 or 128)
if actual_num_heads < padded_num_heads:
logger.warning_once(
f"Padding num_heads from {actual_num_heads} to "
f"{padded_num_heads} for FP8 sparse decode kernel"
)
q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
q_padded[:, :, :actual_num_heads, :] = q
q = q_padded
out, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
block_table=kernel_metadata.dummy_block_table,
@@ -917,6 +955,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
softmax_scale=self.softmax_scale,
)
# Slice output back to actual head count if we padded
if actual_num_heads < padded_num_heads:
out = out[:, :, :actual_num_heads, :]
return out, lse
def _bf16_flash_mla_kernel(
self,
q: torch.Tensor,
@@ -930,13 +974,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell
if self.num_heads % self.padding != 0:
assert self.padding % self.num_heads == 0
if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0
logger.warning_once(
f"padding num_heads to {self.padding} due to sparse attn "
"kernel requirement"
f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel"
)
q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q
q = q_padded