From 578977bb5ed208c62cf9cff80d955836775e0d24 Mon Sep 17 00:00:00 2001
From: Pavani Majety
Date: Tue, 10 Feb 2026 13:18:43 -0800
Subject: [PATCH] [SM100] Resubmit FMHA FP8 prefill for MLA (#31195)

Signed-off-by: Pavani Majety
---
 tests/v1/attention/test_mla_backends.py  |   7 +-
 vllm/config/attention.py                 |   3 +
 .../layers/attention/mla_attention.py    | 158 +++++++++++++++---
 3 files changed, 145 insertions(+), 23 deletions(-)

diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 815274e1c..ba70c8251 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -27,7 +27,7 @@ from vllm.v1.attention.backend import CommonAttentionMetadata
 from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
-from vllm.v1.kv_cache_interface import FullAttentionSpec
+from vllm.v1.kv_cache_interface import MLAAttentionSpec
 
 BACKENDS_TO_TEST = [
     AttentionBackendEnum.CUTLASS_MLA,
@@ -512,7 +512,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
 
 def run_attention_backend(
     backend: AttentionBackendEnum,
-    kv_cache_spec: FullAttentionSpec,
+    kv_cache_spec: MLAAttentionSpec,
     layer_names: list[str],
     vllm_config,
     device: torch.device,
@@ -989,7 +989,7 @@ def test_backend_correctness(
         kv_cache = kv_cache_per_block_size[block_size]
 
         # Create kv_cache_spec with the correct block_size for this backend
-        backend_kv_cache_spec = FullAttentionSpec(
+        backend_kv_cache_spec = MLAAttentionSpec(
            block_size=block_size,
            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                vllm_config.parallel_config
@@ -997,6 +997,7 @@ def test_backend_correctness(
             head_size=vllm_config.model_config.get_head_size(),
             dtype=vllm_config.model_config.dtype,
             sliding_window=vllm_config.model_config.get_sliding_window(),
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
         )
 
         backend_output = run_attention_backend(
diff --git a/vllm/config/attention.py b/vllm/config/attention.py
index 9379b2878..97a139c79 100644
--- a/vllm/config/attention.py
+++ b/vllm/config/attention.py
@@ -43,6 +43,9 @@ class AttentionConfig:
     disable_flashinfer_q_quantization: bool = False
     """If set, when using fp8 kv, do not quantize Q to fp8."""
 
+    use_prefill_query_quantization: bool = False
+    """If set, quantize query for attention in prefill."""
+
     def compute_hash(self) -> str:
         """
         Provide a hash that uniquely identifies all the configs
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index c31aa7b41..c44bf1f16 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -1052,6 +1052,7 @@ class MLACommonPrefillMetadata:
     query_seq_lens: torch.Tensor | None = None
     workspace_buffer: torch.Tensor | None = None
     q_data_type: torch.dtype | None = None
+    output_dtype: torch.dtype | None = None
 
 
 @dataclass
@@ -1145,6 +1146,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
     return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
 
 
+@functools.cache
 def use_flashinfer_prefill() -> bool:
     # For blackwell default to flashinfer prefill if it's available since
     # it is faster than FA2.
@@ -1162,6 +1164,7 @@ def use_flashinfer_prefill() -> bool:
     return is_deepseek_r1_mla_compatible(vllm_config)
 
 
+@functools.cache
 def use_cudnn_prefill() -> bool:
     from vllm.config import get_current_vllm_config
 
@@ -1174,6 +1177,7 @@ def use_cudnn_prefill() -> bool:
     )
 
 
+@functools.cache
 def use_trtllm_ragged_deepseek_prefill() -> bool:
     """Check if TRT-LLM ragged DeepSeek prefill should be used."""
     from vllm.config import get_current_vllm_config
@@ -1210,6 +1214,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
     )
 
 
+@functools.cache
+def backend_supports_prefill_query_quantization() -> bool:
+    """Check if the selected MLA backend supports prefill query quantization.
+
+    Currently supported backends:
+    - FlashInfer prefill
+    - TRT-LLM ragged DeepSeek prefill
+
+    Not supported:
+    - cuDNN prefill
+    - FlashAttention
+    - Non-GB200 devices (FP8 prefill requires device capability 100)
+    """
+    # FP8 prefill query quantization requires GB200 (device capability 100)
+    # for the necessary FP8 kernels at the moment.
+    if not current_platform.is_device_capability_family(100):
+        return False
+
+    return use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill()
+
+
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
@@ -1262,6 +1287,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         return chunked_prefill_workspace_size
 
+    @staticmethod
+    def determine_prefill_query_data_type(
+        vllm_config: VllmConfig,
+        model_dtype: torch.dtype,
+    ) -> torch.dtype:
+        """
+        Determine the query data type for prefill queries.
+        Return FP8 dtype if cache is FP8 and prefill query quantization
+        is enabled, else model dtype.
+        """
+        use_fp8 = (
+            vllm_config.cache_config.cache_dtype.startswith("fp8")
+            and vllm_config.attention_config.use_prefill_query_quantization
+            and backend_supports_prefill_query_quantization()
+        )
+
+        if use_fp8:
+            fp8_dtype = current_platform.fp8_dtype()
+            logger.info_once(
+                "FP8 prefill attention enabled: query data type is FP8", scope="local"
+            )
+            return fp8_dtype
+        elif vllm_config.attention_config.use_prefill_query_quantization:
+            logger.info_once(
+                "Unable to perform FP8 prefill attention even though"
+                " use_prefill_query_quantization is enabled. Please"
+                " ensure that --kv-cache-dtype is set to fp8 and your prefill"
+                " backend is compatible with FP8 attention.",
+                scope="local",
+            )
+            return model_dtype
+
+        return model_dtype
+
     def __init__(
         self,
         kv_cache_spec: AttentionSpec,
@@ -1285,6 +1344,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
         self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
+
+        self.kv_cache_spec = kv_cache_spec
+        self.q_data_type = self.determine_prefill_query_data_type(
+            vllm_config, self.model_config.dtype
+        )
+
         try:
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
@@ -1325,7 +1390,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 self.chunked_prefill_workspace_size,
                 self.model_config.get_head_size(),
             ),
-            dtype=self.model_config.dtype,
+            dtype=self.q_data_type,
             device=device,
         )
 
@@ -1435,7 +1500,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             sm_scale=self._global_hyperparameters.sm_scale,
             window_left=self._global_hyperparameters.window_left,
             logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-            q_data_type=self.model_config.dtype,
+            q_data_type=self.q_data_type,
+            o_data_type=prefill.output_dtype,
         )
 
         # Prepare context prefills
@@ -1454,7 +1520,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 sm_scale=self._global_hyperparameters.sm_scale,
                 window_left=self._global_hyperparameters.window_left,
                 logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-                q_data_type=self.model_config.dtype,
+                q_data_type=self.q_data_type,
+                o_data_type=prefill.output_dtype,
             )
 
         prefill.prefill_main = self._fi_prefill_main
@@ -1709,6 +1776,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
+                output_dtype=self.model_config.dtype,
+                q_data_type=self.q_data_type,
             )
 
             if self._use_cudnn_prefill:
@@ -1894,7 +1963,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 
         self.kv_b_proj = kv_b_proj
         self.indexer = indexer
         self.q_pad_num_heads = q_pad_num_heads
-        self.supports_quant_query_input = True
 
         # Use flashinfer's optimized concat_mla_k kernel when available.
@@ -2129,6 +2197,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         assert prefill.query_seq_lens is not None
         assert prefill.workspace_buffer is not None
 
+        # allocate BF16 / FP16 output tensor for TRT-LLM ragged attention
+        out = torch.empty(
+            q.shape[0],
+            q.shape[1],
+            v.shape[2],
+            device=q.device,
+            dtype=prefill.output_dtype,
+        )
         ret = trtllm_ragged_attention_deepseek(
             query=q,
@@ -2148,6 +2224,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             enable_pdl=False,
             is_causal=True,
             return_lse=return_softmax_lse,
+            out=out,
         )
 
         if isinstance(ret, tuple):
@@ -2170,7 +2247,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             q.shape[1],
             v.shape[2],
             device=q.device,
-            dtype=q.dtype,
+            dtype=prefill.output_dtype,
         )
 
         prefill.workspace_buffer.fill_(0)
@@ -2240,29 +2317,59 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         prefill_metadata = attn_metadata.prefill
         assert prefill_metadata.chunked_context is not None
 
+        use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
+
         output = None
         iters = len(prefill_metadata.chunked_context.seq_tot)
         workspace = prefill_metadata.chunked_context.workspace
+
+        if use_fp8_prefill:
+            q = q.to(prefill_metadata.q_data_type)
+
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            ops.gather_and_maybe_dequant_cache(
-                src_cache=kv_c_and_k_pe_cache,
-                dst=workspace,
-                block_table=prefill_metadata.block_table,
-                cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
-                token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
-                num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
-                kv_cache_dtype=self.kv_cache_dtype,
-                scale=k_scale,
-                seq_starts=prefill_metadata.chunked_context.starts[i],
-            )
+            if not use_fp8_prefill:
+                ops.gather_and_maybe_dequant_cache(
+                    src_cache=kv_c_and_k_pe_cache,
+                    dst=workspace,
+                    block_table=prefill_metadata.block_table,
+                    cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
+                    token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+                    num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
+                    kv_cache_dtype=self.kv_cache_dtype,
+                    scale=k_scale,
+                    seq_starts=prefill_metadata.chunked_context.starts[i],
+                )
+            else:
+                # FP8 path: gather cache without dequantization
+                ops.cp_gather_cache(
+                    src_cache=kv_c_and_k_pe_cache,
+                    dst=workspace,
+                    block_table=prefill_metadata.block_table,
+                    cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
+                    batch_size=attn_metadata.num_prefills,
+                    seq_starts=prefill_metadata.chunked_context.starts[i],
+                )
+
+            # Extract kv_c_normed from workspace
             kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
-            k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
+            # When FP8 weights are used without FP8 prefill, kv_b_proj expects
+            # model dtype input and will quantize internally.
+            if (
+                use_fp8_prefill
+                or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
+            ):
+                kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
+            k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
 
             kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                 -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
             )
+
+            # TODO: Use epilogue of kv_b_proj to generate fp8 kv_nope.
+            if use_fp8_prefill:
+                kv_nope = kv_nope.to(prefill_metadata.q_data_type)
+                k_pe = k_pe.to(prefill_metadata.q_data_type)
             k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
             k = self._concat_k_nope_k_pe(k_nope, k_pe)
@@ -2412,16 +2519,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         assert attn_metadata.prefill is not None
         assert self.dcp_world_size != -1
 
-        has_context = attn_metadata.prefill.chunked_context is not None
+        prefill_metadata = attn_metadata.prefill
+        use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
+
+        # Convert q to FP8 if FP8 prefill attention is enabled
+        if use_fp8_prefill:
+            q = q.to(prefill_metadata.q_data_type)
+
+        has_context = prefill_metadata.chunked_context is not None
+
         kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
         )
         k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-
         k = self._concat_k_nope_k_pe(k_nope, k_pe)
+        if use_fp8_prefill:
+            k = k.to(prefill_metadata.q_data_type)
+            v = v.to(prefill_metadata.q_data_type)
+
         output_prefill = self._run_prefill_new_tokens(
-            prefill=attn_metadata.prefill,
+            prefill=prefill_metadata,
             q=q,
             k=k,
             v=v,
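
Reviewer note: the heart of this patch is the decision of which dtype prefill queries are built in. Below is a minimal standalone sketch of that decision, mirroring determine_prefill_query_data_type and backend_supports_prefill_query_quantization for sanity-checking outside vLLM. The function name, the boolean stand-ins for the current_platform and backend probes, and the hard-coded torch.float8_e4m3fn are assumptions for illustration, not vLLM APIs.

import torch


def pick_prefill_query_dtype(
    cache_dtype: str,
    use_prefill_query_quantization: bool,
    device_capability_family: int,
    flashinfer_prefill: bool,
    trtllm_ragged_deepseek_prefill: bool,
    model_dtype: torch.dtype,
) -> torch.dtype:
    """Standalone mirror of the patch's prefill query dtype selection (sketch)."""
    # Backend gate from backend_supports_prefill_query_quantization(): only
    # FlashInfer prefill or TRT-LLM ragged DeepSeek prefill on a device
    # capability family 100 (GB200-class) part qualify; cuDNN / FA do not.
    backend_ok = device_capability_family == 100 and (
        flashinfer_prefill or trtllm_ragged_deepseek_prefill
    )
    # Same three-way condition as determine_prefill_query_data_type():
    # FP8 KV cache + opt-in flag + supported backend -> FP8 queries.
    use_fp8 = (
        cache_dtype.startswith("fp8")
        and use_prefill_query_quantization
        and backend_ok
    )
    # The patch calls current_platform.fp8_dtype(); e4m3 is assumed here.
    return torch.float8_e4m3fn if use_fp8 else model_dtype


# FP8 KV cache + flag + FlashInfer prefill on capability 100 -> FP8 query dtype.
print(pick_prefill_query_dtype("fp8", True, 100, True, False, torch.bfloat16))
# cuDNN-only prefill (no supported backend) falls back to the model dtype.
print(pick_prefill_query_dtype("fp8", True, 100, False, False, torch.bfloat16))

Note that the output dtype stays at the model dtype even on the FP8 path: the patch threads output_dtype=self.model_config.dtype through MLACommonPrefillMetadata so the TRT-LLM ragged kernel writes BF16/FP16 results while consuming FP8 q/k/v.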