diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md
index e7170babb..a2079e70d 100644
--- a/docs/design/attention_backends.md
+++ b/docs/design/attention_backends.md
@@ -206,7 +206,7 @@ configuration.
 |---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
 | `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
 | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
-| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
+| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
 | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
 | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
 | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 32c0b9064..86efefc37 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -327,6 +327,12 @@ class MockSparseMLAAttentionLayer:
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0
 
+        self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
+            static=True,
+            group_shape=GroupShape.PER_TENSOR,
+            compile_native=True,
+        )
+
     def forward_impl(
         self,
         q: torch.Tensor,
@@ -338,6 +344,7 @@
    ) -> torch.Tensor:
        """Forward for sparse MLA - uses forward_mqa for all tokens."""
        kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
+       fp8_attention = kv_cache_dtype.startswith("fp8")
 
        # Write to KV cache
        if kv_cache.numel() > 0:
@@ -350,6 +357,9 @@
                scale=self._k_scale,
            )
 
+       if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
+           kv_cache = kv_cache.view(current_platform.fp8_dtype())
+
        num_tokens = q.shape[0]
 
        # Sparse MLA uses forward_mqa for all tokens
@@ -367,8 +377,14 @@
        # Convert from (N, B, L) to (B, N, L)
        mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
 
-       # Pass as tuple to forward_mqa
-       mqa_q = (mqa_ql_nope, mqa_q_pe)
+       if fp8_attention and self.impl.supports_quant_query_input:
+           assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
+           assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
+           mqa_q = self._decode_concat_quant_fp8_op(
+               mqa_ql_nope, mqa_q_pe, self._q_scale
+           )
+       else:
+           mqa_q = (mqa_ql_nope, mqa_q_pe)
 
        attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
 
diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py
index 86cefa036..0fd0ba6fa 100644
--- a/tests/v1/attention/test_sparse_mla_backends.py
+++ b/tests/v1/attention/test_sparse_mla_backends.py
@@ -191,6 +191,16 @@ def test_sparse_backend_decode_correctness(
     if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
         pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
 
+    if (
+        backend_cls == FlashMLASparseBackend
+        and kv_cache_dtype.startswith("fp8")
+        and kv_cache_dtype != "fp8_ds_mla"
+    ):
+        pytest.skip(
+            "FlashMLA Sparse Attention backend fp8 only supports "
+            "fp8_ds_mla kv-cache dtype"
+        )
+
     supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
     if block_size not in supported_block_sizes:
         pytest.skip(
@@ -419,7 +429,7 @@
         num_blocks=vllm_config.cache_config.num_gpu_blocks,
         common_attn_metadata=common_attn_metadata,
         randomize_blocks=False,
-        kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
+        kv_cache_dtype=kv_cache_dtype,
         scale=kv_cache_scale,
     )
 
diff --git a/tools/pre_commit/generate_attention_backend_docs.py b/tools/pre_commit/generate_attention_backend_docs.py
index 628656f0d..3ec2248a8 100644
--- a/tools/pre_commit/generate_attention_backend_docs.py
+++ b/tools/pre_commit/generate_attention_backend_docs.py
@@ -49,6 +49,11 @@ MLA_ATTENTION_FILE = (
 # Backends to skip during doc generation
 SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}
 
+BACKEND_KV_DTYPE_EXCLUDES: dict[str, set[str]] = {
+    # fp8 is an alias for fp8_ds_mla for FlashMLA Sparse
+    "FLASHMLA_SPARSE": {"fp8"},
+}
+
 
 def is_relevant_file(filepath: str) -> bool:
     """Check if a file matches any of the relevant patterns."""
@@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
         tree, impl_class_name, "can_return_lse_for_decode", False, file_path
     )
 
+    kv_cache_dtypes = parse_kv_cache_dtypes(class_node)
+    if backend_name in BACKEND_KV_DTYPE_EXCLUDES:
+        excluded = BACKEND_KV_DTYPE_EXCLUDES[backend_name]
+        kv_cache_dtypes = ", ".join(
+            d
+            for d in (d.strip() for d in kv_cache_dtypes.split(","))
+            if d not in excluded
+        )
+
     return {
         "name": backend_name,
         "dtypes": parse_supported_dtypes(class_node),
-        "kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
+        "kv_cache_dtypes": kv_cache_dtypes,
         "block_sizes": parse_block_sizes(class_node),
         "head_sizes": parse_head_sizes(class_node),
         "attn_types": parse_attention_types(class_node),
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index b0e16fa52..97ae3ef1b 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -331,11 +331,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             calculate_kv_scales = False
         self.quant_config = quant_config
 
-        # Initialize KV cache quantization attributes
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        _init_kv_cache_quant(self, quant_config, prefix)
-
         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
             self.head_size,
@@ -347,6 +342,36 @@
             num_heads=self.num_heads,
         )
 
+        # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
+        # Automatically convert fp8 kv-cache format to "fp8_ds_mla"
+        if (
+            self.attn_backend.get_name() == "FLASHMLA_SPARSE"
+            and kv_cache_dtype.startswith("fp8")
+            and kv_cache_dtype != "fp8_ds_mla"
+        ):
+            assert cache_config is not None
+            cache_config.cache_dtype = "fp8_ds_mla"
+            kv_cache_dtype = "fp8_ds_mla"
+            logger.info_once(
+                "Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
+                "fp8 kv-cache format, please set `--attention-backend "
+                "FLASHINFER_MLA_SPARSE`"
+            )
+
+        if (
+            self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
+            and kv_cache_dtype.startswith("fp8")
+        ):
+            logger.info_once(
+                "Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
+                "KV cache format, please set `--attention-backend FLASHMLA_SPARSE`"
+            )
+
+        # Initialize KV cache quantization attributes
+        self.kv_cache_dtype = kv_cache_dtype
+        self.calculate_kv_scales = calculate_kv_scales
+        _init_kv_cache_quant(self, quant_config, prefix)
+
         if (
             cache_config is not None
             and cache_config.enable_prefix_caching
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 734e3ad23..0e35bedbc 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -31,20 +31,13 @@ class VerifyAndUpdateConfig:
 class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
     @classmethod
     def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
-        """
-        Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
-        """
         hf_config = vllm_config.model_config.hf_config
 
         # Mirror the check in vllm/model_executor/models/deepseek_v2.py
         is_v32 = hasattr(hf_config, "index_topk")
         assert is_v32
 
-        # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
         cache_config = vllm_config.cache_config
-        if cache_config.cache_dtype.startswith("fp8"):
-            cache_config.cache_dtype = "fp8_ds_mla"
-            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
         if cache_config.cache_dtype == "bfloat16":
             cache_config.cache_dtype = "auto"
             logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
index 21a0d99c2..34683d3f6 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
@@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend):
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "bfloat16",
+        "fp8",
+        "fp8_e4m3",
     ]
 
     @staticmethod
@@ -304,6 +306,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
+        # fp8 query quantization is required when using fp8 kv_cache,
+        # as the TRTLLM-GEN sparse MLA kernel requires matching dtypes
+        # for query and kv_cache (mixed bf16+fp8 is not supported).
+        self.supports_quant_query_input = True
+
     def forward_mqa(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index c8a78af4a..c0cdc204d 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend):
         "auto",
         "bfloat16",
         "fp8_ds_mla",
+        "fp8",  # alias for fp8_ds_mla
     ]
 
     @staticmethod
@@ -567,6 +568,12 @@
         )
         self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
 
+        if kv_cache_dtype.startswith("fp8"):
+            assert kv_cache_dtype == "fp8_ds_mla", (
+                "FlashMLA Sparse Attention backend fp8 only supports "
+                "fp8_ds_mla kv-cache dtype"
+            )
+
         if kv_cache_dtype == "fp8_ds_mla":
             # Reserve workspace during initialization
             vllm_config = get_current_vllm_config()