Update FlashMLA (#32491)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -17,7 +17,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
@@ -397,6 +396,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
+        # FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
+        self.fp8_decode_padded_heads = (
+            FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
+        )

        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
@@ -417,14 +420,20 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
        )

-        # Equation taken from FlashMLA/csrc/pybind.cpp
-        h_q, h_k = self.num_heads, 1
-        s_q = 1  # num_sm_parts is inversely proportional to s_q; s_q = 1 is the largest
-        max_num_sm_parts = int(
-            max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
-        )
+        # Equation taken from FlashMLA/csrc/api/sparse_decode.h
+        # For sparse FP8 decode, the formula depends on the architecture:
+        # - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
+        # - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
+        # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
+        # For the max buffer size, use s_q = 1 (the case producing the largest output)
+        # Use the padded head count, since that is what is passed to the kernel
+        h_q = self.fp8_decode_padded_heads
        if current_platform.is_device_capability_family(100):
-            max_num_sm_parts *= 2
+            # SM100 head64 or head64x2 uses the full SM count
+            max_num_sm_parts = sm_count
+        else:
+            # SM90 uses the h_q/64 divisor
+            max_num_sm_parts = sm_count // max(1, h_q // 64)
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
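For intuition, here is a standalone sketch of the worst-case sizing logic above. It assumes sm_count is the device's streaming-multiprocessor count and h_q is the padded head count (64 or 128); the helper name and the example SM counts are illustrative, not part of this diff.

def max_sm_parts(sm_count: int, h_q: int, is_sm100: bool) -> int:
    # s_q = 1 maximizes num_sm_parts, so it bounds the buffer size
    s_q = 1
    if is_sm100:
        # SM100: size the buffer for the full SM count
        return sm_count // s_q
    # SM90: num_sm_parts = num_sms / s_q / (h_q / 64)
    return sm_count // s_q // max(1, h_q // 64)

# e.g. 132 SMs (an H100) with h_q = 128 gives 66 parts on SM90,
# while an SM100 device with 148 SMs is sized for all 148 parts:
assert max_sm_parts(132, 128, is_sm100=False) == 66
assert max_sm_parts(148, 128, is_sm100=True) == 148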
@@ -455,12 +464,15 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
        """
        num_tokens = common_attn_metadata.num_actual_tokens

+        # Use padded head count since that's what the kernel will see
+        padded_heads = self.fp8_decode_padded_heads
+
        # Build metadata for all tokens as a single batch
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
-            num_q_tokens_per_head_k=num_tokens * self.num_heads,
+            num_q_tokens_per_head_k=num_tokens * padded_heads,
            topk=self.topk_tokens,
-            num_heads_q=self.num_heads,
+            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )
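The point of threading padded_heads through get_mla_metadata is that the scheduler and the kernel must agree on h_q: the impl pads the query up to 64 or 128 heads before the decode call, so the tile-scheduler metadata has to be sized for the padded count as well. A minimal sketch of the invariant; the concrete numbers are illustrative only:

num_heads = 16                                 # e.g. 128 query heads at TP=8
padded_heads = 64 if num_heads <= 64 else 128  # _compute_fp8_decode_padded_heads
num_tokens = 32
# Both head-dependent arguments are derived from padded_heads, never num_heads:
assert num_tokens * padded_heads == 2048       # num_q_tokens_per_head_k
assert padded_heads == 64                      # num_heads_q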
@@ -606,11 +618,13 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

+        # Use padded head count since that's what the kernel will see
+        padded_heads = self.fp8_decode_padded_heads
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:num_decodes],
-            num_q_tokens_per_head_k=decode_query_len * self.num_heads,
+            num_q_tokens_per_head_k=decode_query_len * padded_heads,
            topk=self.topk_tokens,
-            num_heads_q=self.num_heads,
+            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )
@@ -689,6 +703,12 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):


class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
+    @staticmethod
+    def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
+        # FP8 decode kernel only supports h_q = 64 or 128;
+        # compute the padded head count for decode
+        return 64 if num_heads <= 64 else 128
+
    def __init__(
        self,
        num_heads: int,
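Since the new helper is a pure function, its contract is easy to pin down; a quick illustrative check of the padded counts a 128-query-head model (e.g. DeepSeek-V3) yields under common tensor-parallel sizes:

def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
    # FP8 decode kernel only supports h_q = 64 or 128
    return 64 if num_heads <= 64 else 128

# Per-rank head counts of 16, 32, 64 all pad to 64; 128 stays at 128:
for tp, expected in [(8, 64), (4, 64), (2, 64), (1, 128)]:
    assert _compute_fp8_decode_padded_heads(128 // tp) == expected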
@@ -722,7 +742,11 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
-        self.padding = 128 if current_platform.is_device_capability_family(100) else 64
+        # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
+        self.prefill_padding = (
+            128 if current_platform.is_device_capability_family(100) else 64
+        )
+        self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

        if kv_cache_dtype == "fp8_ds_mla":
            # Reserve workspace during initialization
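After this hunk the impl carries two independent padding targets: prefill_padding is architecture-dependent, while fp8_decode_padded_heads depends only on the per-rank head count. A hedged sketch of the resulting matrix, assuming capability family 100 means Blackwell and everything else here is Hopper:

def paddings(num_heads: int, is_blackwell: bool) -> tuple[int, int]:
    prefill_padding = 128 if is_blackwell else 64
    fp8_decode_padded_heads = 64 if num_heads <= 64 else 128
    return prefill_padding, fp8_decode_padded_heads

assert paddings(16, is_blackwell=False) == (64, 64)    # Hopper, TP=8
assert paddings(16, is_blackwell=True) == (128, 64)    # Blackwell, TP=8
assert paddings(128, is_blackwell=False) == (64, 128)  # Hopper, TP=1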
@@ -903,8 +927,22 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
-    ) -> torch.Tensor:
-        return flash_mla_with_kvcache(
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # q shape: (batch, seq_len, num_heads, head_dim)
+        actual_num_heads = q.size(2)
+        padded_num_heads = self.fp8_decode_padded_heads
+
+        # Pad query if needed (kernel only supports h_q = 64 or 128)
+        if actual_num_heads < padded_num_heads:
+            logger.warning_once(
+                f"Padding num_heads from {actual_num_heads} to "
+                f"{padded_num_heads} for FP8 sparse decode kernel"
+            )
+            q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
+            q_padded[:, :, :actual_num_heads, :] = q
+            q = q_padded
+
+        out, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=kernel_metadata.dummy_block_table,
@@ -917,6 +955,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
            softmax_scale=self.softmax_scale,
        )

+        # Slice output back to actual head count if we padded
+        if actual_num_heads < padded_num_heads:
+            out = out[:, :, :actual_num_heads, :]
+
+        return out, lse
+
    def _bf16_flash_mla_kernel(
        self,
        q: torch.Tensor,
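Zero-padding the query adds heads whose output the caller never reads, because the result is sliced back to the actual head count before returning. A self-contained round-trip sketch, with an identity-style stub standing in for flash_mla_with_kvcache (the real kernel computes attention independently per head, so the slice is equally safe there):

import torch

def pad_call_slice(q: torch.Tensor, padded_heads: int) -> torch.Tensor:
    actual = q.size(2)  # q: (batch, seq_len, num_heads, head_dim)
    if actual < padded_heads:
        q_padded = q.new_zeros((q.size(0), q.size(1), padded_heads, q.size(3)))
        q_padded[:, :, :actual, :] = q
        q = q_padded
    out = q * 2.0  # stand-in for the per-head kernel computation
    return out[:, :, :actual, :]  # slice the padding back off

q = torch.randn(2, 1, 16, 576)  # 16 heads, padded to 64 internally
out = pad_call_slice(q, 64)
assert out.shape == q.shape
torch.testing.assert_close(out, q * 2.0)  # padded heads never leak out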
@@ -930,13 +974,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on Hopper and 128 on Blackwell
-        if self.num_heads % self.padding != 0:
-            assert self.padding % self.num_heads == 0
+        if self.num_heads % self.prefill_padding != 0:
+            assert self.prefill_padding % self.num_heads == 0
            logger.warning_once(
-                f"padding num_heads to {self.padding} due to sparse attn "
-                "kernel requirement"
+                f"Padding num_heads from {self.num_heads} to "
+                f"{self.prefill_padding} for BF16 sparse prefill kernel"
            )
-            q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
+            q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
            q_padded[:, : self.num_heads, :] = q
            q = q_padded

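The BF16 prefill guard only fires when num_heads is not already a multiple of prefill_padding; the assert then requires num_heads to divide prefill_padding evenly, i.e. that padding up to a single supported launch size is actually possible. A minimal sketch of that divisibility logic (the helper name is hypothetical):

def check_prefill_padding(num_heads: int, prefill_padding: int) -> int:
    # Returns the head count the kernel will be launched with.
    if num_heads % prefill_padding != 0:
        assert prefill_padding % num_heads == 0, (
            f"cannot pad {num_heads} heads to {prefill_padding}"
        )
        return prefill_padding
    return num_heads

assert check_prefill_padding(16, 64) == 64    # Hopper, TP=8: pad up
assert check_prefill_padding(128, 64) == 128  # already a multiple: no-op
assert check_prefill_padding(64, 128) == 128  # Blackwell, TP=2: pad up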