[SM100] Resubmit FMHA FP8 prefill for MLA (#31195)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2026-02-10 13:18:43 -08:00
committed by GitHub
parent 9615575afc
commit 578977bb5e
3 changed files with 145 additions and 23 deletions

View File

@@ -1052,6 +1052,7 @@ class MLACommonPrefillMetadata:
query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
output_dtype: torch.dtype | None = None
@dataclass
@@ -1145,6 +1146,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
@functools.cache
def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
@@ -1162,6 +1164,7 @@ def use_flashinfer_prefill() -> bool:
return is_deepseek_r1_mla_compatible(vllm_config)
@functools.cache
def use_cudnn_prefill() -> bool:
from vllm.config import get_current_vllm_config
@@ -1174,6 +1177,7 @@ def use_cudnn_prefill() -> bool:
)
@functools.cache
def use_trtllm_ragged_deepseek_prefill() -> bool:
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
from vllm.config import get_current_vllm_config
@@ -1210,6 +1214,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
)
@functools.cache
def backend_supports_prefill_query_quantization() -> bool:
"""Check if the selected MLA backend supports prefill query quantization.
Currently supported backends:
- FlashInfer prefill
- TRT-LLM ragged DeepSeek prefill
Not supported:
- cuDNN Prefill
- FlashAttention
- Non-GB200 devices (FP8 prefill requires device capability 100)
"""
# FP8 prefill query quantization requires GB200 (device capability 100)
# for the necessary FP8 kernels at the moment.
if not current_platform.is_device_capability_family(100):
return False
return use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill()
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
@@ -1262,6 +1287,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return chunked_prefill_workspace_size
@staticmethod
def determine_prefill_query_data_type(
vllm_config: VllmConfig,
model_dtype: torch.dtype,
) -> torch.dtype:
"""
Determine the query data type for prefill queries.
Return FP8 dtype if cache is FP8 and prefill query quantization
is enabled, else model dtype.
"""
use_fp8 = (
vllm_config.cache_config.cache_dtype.startswith("fp8")
and vllm_config.attention_config.use_prefill_query_quantization
and backend_supports_prefill_query_quantization()
)
if use_fp8:
fp8_dtype = current_platform.fp8_dtype()
logger.info_once(
"FP8 prefill attention enabled: query data type is FP8", scope="local"
)
return fp8_dtype
elif vllm_config.attention_config.use_prefill_query_quantization:
logger.info_once(
"Unable to perform FP8 prefill attention when"
" use_prefill_query_quantization is enabled. Please"
" ensure that --kv-cache-dtype is set to fp8 and your prefill"
" backend is compatible with FP8 attention.",
scope="local",
)
return model_dtype
return model_dtype
def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -1285,6 +1344,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
self.kv_cache_spec = kv_cache_spec
self.q_data_type = self.determine_prefill_query_data_type(
vllm_config, self.model_config.dtype
)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -1325,7 +1390,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.chunked_prefill_workspace_size,
self.model_config.get_head_size(),
),
dtype=self.model_config.dtype,
dtype=self.q_data_type,
device=device,
)
@@ -1435,7 +1500,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
o_data_type=prefill.output_dtype,
)
# Prepare context prefills
@@ -1454,7 +1520,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
o_data_type=prefill.output_dtype,
)
prefill.prefill_main = self._fi_prefill_main
@@ -1709,6 +1776,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
output_dtype=self.model_config.dtype,
q_data_type=self.q_data_type,
)
if self._use_cudnn_prefill:
@@ -1894,7 +1963,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.supports_quant_query_input = True
# Use flashinfer's optimized concat_mla_k kernel when available.
@@ -2129,6 +2197,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
# allocate BF16 / FP16 output tensor for TRT-LLM ragged attention
out = torch.empty(
q.shape[0],
q.shape[1],
v.shape[2],
device=q.device,
dtype=prefill.output_dtype,
)
ret = trtllm_ragged_attention_deepseek(
query=q,
@@ -2148,6 +2224,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
enable_pdl=False,
is_causal=True,
return_lse=return_softmax_lse,
out=out,
)
if isinstance(ret, tuple):
@@ -2170,7 +2247,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q.shape[1],
v.shape[2],
device=q.device,
dtype=q.dtype,
dtype=prefill.output_dtype,
)
prefill.workspace_buffer.fill_(0)
@@ -2240,29 +2317,59 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
if use_fp8_prefill:
q = q.to(prefill_metadata.q_data_type)
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
if not use_fp8_prefill:
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
else:
# FP8 path: gather cache without dequantization
ops.cp_gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
# Extract kv_c_normed from workspace
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
# When FP8 weights are used without FP8 prefill, kv_b_proj expects
# model dtype input and will quantize internally.
if (
use_fp8_prefill
or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
):
kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
# To Do: Use epilogue of kv_b_proj to generate fp8 kv_nope.
if use_fp8_prefill:
kv_nope = kv_nope.to(prefill_metadata.q_data_type)
k_pe = k_pe.to(prefill_metadata.q_data_type)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
@@ -2412,16 +2519,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert attn_metadata.prefill is not None
assert self.dcp_world_size != -1
has_context = attn_metadata.prefill.chunked_context is not None
prefill_metadata = attn_metadata.prefill
use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
# Convert q to FP8 if FP8 prefill attention is enabled
if use_fp8_prefill:
q = q.to(prefill_metadata.q_data_type)
has_context = prefill_metadata.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
if use_fp8_prefill:
k = k.to(prefill_metadata.q_data_type)
v = v.to(prefill_metadata.q_data_type)
output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
prefill=prefill_metadata,
q=q,
k=k,
v=v,