From 1f5ec2889c4176593c2fedae3b33b6244c996d29 Mon Sep 17 00:00:00 2001 From: Carl Y <4531192+carlyou@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:16:11 -0700 Subject: [PATCH] [mla] Support fused FP8/NVFP4 output quantization in MLA attention (#35792) (#36205) Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com> Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- .buildkite/test_areas/compile.yaml | 2 + docs/design/fusions.md | 20 +- tests/compile/fusions_e2e/conftest.py | 12 +- tests/compile/fusions_e2e/models.py | 33 +- tests/compile/fusions_e2e/test_tp1_quant.py | 11 +- tests/compile/fusions_e2e/test_tp2_ar_rms.py | 18 +- .../passes/test_mla_attn_quant_fusion.py | 508 ++++++++++++++++++ .../passes/fusion/mla_attn_quant_fusion.py | 262 +++++++++ vllm/compilation/passes/pass_manager.py | 2 + vllm/config/compilation.py | 2 +- .../layers/attention/mla_attention.py | 54 +- vllm/v1/attention/backend.py | 21 + 12 files changed, 928 insertions(+), 17 deletions(-) create mode 100644 tests/compile/passes/test_mla_attn_quant_fusion.py create mode 100644 vllm/compilation/passes/fusion/mla_attn_quant_fusion.py diff --git a/.buildkite/test_areas/compile.yaml b/.buildkite/test_areas/compile.yaml index c21b66552..aa46447c2 100644 --- a/.buildkite/test_areas/compile.yaml +++ b/.buildkite/test_areas/compile.yaml @@ -72,6 +72,7 @@ steps: - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes - tests/compile/passes/test_fusion_attn.py + - tests/compile/passes/test_mla_attn_quant_fusion.py - tests/compile/passes/test_silu_mul_quant_fusion.py - tests/compile/passes/distributed/test_fusion_all_reduce.py - tests/compile/fullgraph/test_full_graph.py @@ -79,6 +80,7 @@ steps: # b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell - nvidia-smi - pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER + - pytest -v -s 
tests/compile/passes/test_mla_attn_quant_fusion.py - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py # this runner has 2 GPUs available even though num_devices=2 is not set - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py diff --git a/docs/design/fusions.md b/docs/design/fusions.md index cdc825e8d..3fe1f769b 100644 --- a/docs/design/fusions.md +++ b/docs/design/fusions.md @@ -22,6 +22,7 @@ or just on the low or high end. | ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ | | [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low | | [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always | +| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always | | [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low | | [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low | | [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High | @@ -40,6 +41,7 @@ The table below lists the quantization schemes supported by each fusion on each | ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | 
------------- | ---------------------------------------- | | `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — | | `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* | +| `fuse_attn_quant` (MLA)\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static(untested)\* | | `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 | | `enable_qk_norm_rope_fusion` | FP16/BF16 | FP16/BF16 | FP16/BF16† | FP16/BF16† | — | | `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — | @@ -129,7 +131,8 @@ on SM90/SM100) and configurable via `PassConfig.fi_allreduce_fusion_max_size_mb` explicitly. It requires the full model graph to be visible (Inductor partition or `splitting_ops=[]`). **What it fuses.** Fuses the attention output quantization directly after the attention computation, -eliminating a full-precision memory round-trip of the attention output. Patterns covered: +eliminating a full-precision memory round-trip of the attention output. This fusion supports both +standard `Attention` and `MLAAttention` (used by DeepSeek-V2/V3/R1 models). Patterns covered: `Attention → FP8 static quant`: @@ -142,11 +145,24 @@ eliminating a full-precision memory round-trip of the attention output. Patterns - `FLASHINFER`: CUDA sm100+ with FlashInfer installed +`MLAAttention → FP8 static quant` / `MLAAttention → NVFP4 dynamic quant`: + +The MLA fusion operates at the graph level on the `unified_mla_attention_with_output` op and works +with all MLA decode and prefill backend combinations. Unlike standard `Attention` backends (where +the kernel writes FP8 output directly), no MLA prefill or decode backend currently supports direct +FP8/FP4 output. The fusion writes to an intermediate buffer and quantizes in a separate step, so +there is no memory round-trip elimination yet. + +!!! info + The MLA attention fusion is not expected to yield a measurable speedup yet. 
+ This will improve once MLA prefill/decode kernels support direct FP8/FP4 output. + Other attention backends do not support fused output quantization yet. **Code locations.** -- Pass: [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py) +- Pass (Attention): [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py) +- Pass (MLAAttention): [`vllm/compilation/passes/fusion/mla_attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/mla_attn_quant_fusion.py) - Attention backends: [`vllm/v1/attention/backends/`](https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/) ### RoPE + KV-Cache Update (`fuse_rope_kvcache`) diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index ca67d90d2..adc569192 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -84,10 +84,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): rocm_aiter_ops.refresh_env_variables() # Filter here to reduce code duplication + backend_name = attn_backend.backend.name.lower() requires_mla = "deepseek" in model_name.lower() - is_mla = "mla" in attn_backend.backend.name.lower() + is_mla = "mla" in backend_name + # DeepSeek V3.2 uses sparse MLA + requires_sparse = "v3.2" in model_name.lower() + is_sparse = "sparse" in backend_name - if requires_mla != is_mla: + if requires_mla != is_mla or requires_sparse != is_sparse: pytest.skip( f"Incompatible model '{model_name}' and " f"attention backend '{attn_backend.backend.name}'" @@ -231,7 +235,9 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ) elif match_name == "attn_quant_fusion": - actual_match = match_table.get(match_name, 0) + actual_match = match_table.get( + "attn_quant_fusion", 0 + ) + 
match_table.get("mla_attn_quant_fusion", 0) assert actual_match == expected_matches * n_expected, ( f"Could not find {expected_matches * n_expected} " f"{match_name} (found {actual_match})." diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index b174efd25..8d830e884 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -58,6 +58,15 @@ TRITON_MLA_ATTN = pytest.param( id="TRITON_MLA", ) +FLASHMLA_SPARSE_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.FLASHMLA_SPARSE), + id="FLASHMLA_SPARSE", + marks=pytest.mark.skipif( + not is_blackwell(), + reason="FlashMLA Sparse requires Blackwell", + ), +) + # Models llama3_8b = ModelFusionInfo( model_name="meta-llama/Llama-3.1-8B-Instruct", @@ -141,6 +150,18 @@ qwen3_a3b_fp8 = ModelFusionInfo( ), ) +deepseek_coder_v2_lite_fp8 = ModelFusionInfo( + model_name="RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8", + matches=lambda n_layers: Matches( + # first_k_dense_replace=1; MoE hides most rms+quant sites + rms_quant_fusion=1, + act_quant_fusion=min(1, n_layers), # dense layers only + # MLA attn + static FP8 quant + attn_quant_fusion=n_layers, + ar_rms_fusion=n_layers * 2 + 1, + ), +) + deepseek_v3_fp8 = ModelFusionInfo( model_name="deepseek-ai/DeepSeek-V3", matches=lambda n_layers: Matches( @@ -152,7 +173,7 @@ deepseek_v3_fp8 = ModelFusionInfo( rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers # silu+block quant act_quant_fusion=min(3, n_layers), # dense layers only - # MLA attn + quant not supported yet: + # MLA attn + per-group FP8 quant not supported yet: # https://github.com/vllm-project/vllm/issues/35792 attn_quant_fusion=0, ar_rms_fusion=n_layers * 2 + 1, @@ -162,6 +183,16 @@ deepseek_v3_fp8 = ModelFusionInfo( ), ) +deepseek_v32_fp4 = ModelFusionInfo( + model_name="nvidia/DeepSeek-V3.2-NVFP4", + matches=lambda n_layers: Matches( + rms_quant_fusion=0, + act_quant_fusion=0, + attn_quant_fusion=n_layers, 
+ ar_rms_fusion=n_layers * 2 + 1, + ), +) + gpt_oss_20b = ModelFusionInfo( model_name="openai/gpt-oss-20b", matches=lambda n_layers: Matches( diff --git a/tests/compile/fusions_e2e/test_tp1_quant.py b/tests/compile/fusions_e2e/test_tp1_quant.py index 8186ecbb4..ded39939e 100644 --- a/tests/compile/fusions_e2e/test_tp1_quant.py +++ b/tests/compile/fusions_e2e/test_tp1_quant.py @@ -18,11 +18,14 @@ from .common import ( from .models import ( FLASHINFER_ATTN, FLASHINFER_MLA_ATTN, + FLASHMLA_SPARSE_ATTN, ROCM_AITER_UNIFIED_ATTN, ROCM_ATTN, TRITON_ATTN, TRITON_MLA_ATTN, + deepseek_coder_v2_lite_fp8, deepseek_v3_fp8, + deepseek_v32_fp4, llama3_8b_fp4, llama3_8b_fp8, llama4_scout_fp4, @@ -37,6 +40,7 @@ from .models import ( (*llama3_8b_fp8, False), (*qwen3_a3b_fp8, False), (*qwen3_a3b_fp8, True), + (*deepseek_coder_v2_lite_fp8, False), (*deepseek_v3_fp8, False), (*deepseek_v3_fp8, True), pytest.param( @@ -144,9 +148,12 @@ def test_tp1_fp8_fusions( @pytest.mark.parametrize( "model_name, matches_fn, model_kwargs, hf_overrides", - [llama3_8b_fp4, llama4_scout_fp4], + [llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4], +) +@pytest.mark.parametrize( + "attn_backend", + [FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN], ) -@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN]) @pytest.mark.parametrize("n_layers", [6]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index fa1ceb7f0..4b0a0859b 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -18,8 +18,11 @@ from .common import ( from .models import ( FLASHINFER_ATTN, FLASHINFER_MLA_ATTN, + FLASHMLA_SPARSE_ATTN, TRITON_ATTN, + deepseek_coder_v2_lite_fp8, deepseek_v3_fp8, + deepseek_v32_fp4, gpt_oss_20b, llama3_8b, llama3_8b_fp4, @@ -37,7 +40,13 @@ pytestmark = 
pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes @pytest.mark.parametrize( "model_name, matches_fn, model_kwargs, hf_overrides", # qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported - [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8], + [ + llama3_8b_fp8, + llama4_scout_fp8, + qwen3_a3b_fp8, + deepseek_coder_v2_lite_fp8, + deepseek_v3_fp8, + ], ) @pytest.mark.parametrize( "attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN] @@ -104,9 +113,12 @@ def test_tp2_ar_rms_fp8_fusions( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "model_name, matches_fn, model_kwargs, hf_overrides", - [llama3_8b_fp4, llama4_scout_fp4], + [llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4], +) +@pytest.mark.parametrize( + "attn_backend", + [FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN], ) -@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN]) @pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) diff --git a/tests/compile/passes/test_mla_attn_quant_fusion.py b/tests/compile/passes/test_mla_attn_quant_fusion.py new file mode 100644 index 000000000..426fbb6a7 --- /dev/null +++ b/tests/compile/passes/test_mla_attn_quant_fusion.py @@ -0,0 +1,508 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + +import pytest +import torch._dynamo + +from tests.compile.backend import LazyInitPass, TestBackend +from tests.utils import TestFP8Layer, flat_product +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS +from vllm.compilation.passes.fusion.mla_attn_quant_fusion import ( + MLA_ATTN_OP, + MLAAttnQuantFusionPass, +) +from 
vllm.compilation.passes.fx_utils import find_op_nodes +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass +from vllm.config import ( + AttentionConfig, + CacheConfig, + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.attention import MLAAttention +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, + kNvfp4Dynamic, +) +from vllm.platforms import current_platform +from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.kv_cache_interface import MLAAttentionSpec + +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + + +class MLAAttentionQuantPatternModel(torch.nn.Module): + """Base model for MLA AttentionQuantPattern fusion.""" + + def __init__( + self, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + kv_lora_rank: int, + kv_cache_dtype: torch.dtype, + device: torch.device, + vllm_config: VllmConfig, + **kwargs, + ): + super().__init__() + self.num_heads = num_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.kv_lora_rank = kv_lora_rank + self.output_dim = num_heads * v_head_dim + self.head_size = kv_lora_rank + qk_rope_head_dim + self.kv_cache_dtype = kv_cache_dtype + self.device = device + self.vllm_config = vllm_config + + # Create 
kv_b_proj (ColumnParallelLinear) on device. + # Reuse weights from prior model instance when available, because + # ColumnParallelLinear may get NaN from recycled CUDA memory after + # torch.compile runs in the same process. + kv_b_proj = ColumnParallelLinear( + input_size=kv_lora_rank, + output_size=num_heads * (qk_nope_head_dim + v_head_dim), + bias=False, + prefix="model.layers.0.self_attn.kv_b_proj", + ).to(device) + kv_b_proj_weight = kwargs.get("kv_b_proj_weight") + if kv_b_proj_weight is not None: + kv_b_proj.weight.data.copy_(kv_b_proj_weight) + elif kv_b_proj.weight.data.isnan().any(): + # Sanitize NaN from recycled CUDA memory + kv_b_proj.weight.data.normal_() + + # Create MLAAttention + self.mla_attn = MLAAttention( + num_heads=num_heads, + scale=1.0 / (self.qk_head_dim**0.5), + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + kv_b_proj=kv_b_proj, + cache_config=vllm_config.cache_config, + quant_config=self.quant_config, + prefix="model.layers.0.self_attn.attn", + ) + self.mla_attn._k_scale = self.mla_attn._k_scale.to(device) + self.mla_attn._v_scale = self.mla_attn._v_scale.to(device) + + # Initialize W_UK_T and W_UV from kv_b_proj weights + self.mla_attn.process_weights_after_loading(torch.get_default_dtype()) + self.kv_b_proj_weight = kv_b_proj.weight.data.clone() + + self.block_size = 16 + + # Initialize MLA MetadataBuilder + self.builder = self.mla_attn.attn_backend.get_builder_cls()( + kv_cache_spec=MLAAttentionSpec( + block_size=self.block_size, + num_kv_heads=1, + head_size=self.head_size, + dtype=self.kv_cache_dtype, + ), + layer_names=[self.mla_attn.layer_name], + vllm_config=self.vllm_config, + device=self.device, + ) + + def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: + """Initialize MLA attention metadata. + + NOTE: Uses decode-only batch (query_len=1 per request). 
The prefill + (forward_mha) path is not separately tested here because it requires + FlashAttention availability and different input tensor shapes. The + quant logic in forward_impl is identical for both paths — it quantizes + the full output[:num_actual_toks] buffer after both forward_mha and + forward_mqa have written their results. + """ + + batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, self.block_size, self.device, arange_block_indices=True + ) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size + num_blocks = batch_size * max_blocks + + # MLA KV cache is 3D: (num_blocks, block_size, head_size) + attn_backend = self.mla_attn.attn_backend + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, 1, self.head_size + ) + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + + ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + inv_order = [ + kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + ] + + raw_tensor = torch.zeros( + ordered_shape, dtype=self.kv_cache_dtype, device=self.device + ) + kv_cache = raw_tensor.permute(*inv_order) + + self.mla_attn.kv_cache = kv_cache + + self.attn_metadata = self.builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + + return self.attn_metadata + + +class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel): + """Test model for MLA Attention + FP8 static quant fusion.""" + + quant_key = kFp8StaticTensorSym + quant_config = Fp8Config() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fp8_linear = TestFP8Layer( + weight_shape=(self.output_dim, self.output_dim), + activation_quant_key=self.quant_key, + 
weight_quant_key=self.quant_key, + device=self.device, + ) + + w = kwargs.get("w") + if w is not None: + self.fp8_linear.weight = w["weight"] + self.fp8_linear.weight_scale = w["wscale"] + self.fp8_linear.input_scale = w["scale"] + + self.w = { + "weight": self.fp8_linear.weight, + "wscale": self.fp8_linear.weight_scale, + "scale": self.fp8_linear.input_scale, + } + + def forward( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + ): + """Forward pass that creates the MLA attention + FP8 quant pattern.""" + attn_output = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(q.shape[0], self.output_dim), + ) + return self.fp8_linear(attn_output) + + +class TestMLAAttentionNvfp4QuantPatternModel(MLAAttentionQuantPatternModel): + """Test model for MLA Attention + NVFP4 quant fusion.""" + + quant_key = kNvfp4Dynamic + quant_config = ModelOptNvFp4Config( + is_checkpoint_nvfp4_serialized=False, + kv_cache_quant_algo=None, + exclude_modules=[], + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.w = kwargs.get( + "w", + { + "weight": torch.randint( + 256, + (self.output_dim, self.output_dim // 2), + dtype=FP4_DTYPE, + device=self.device, + ), + "wscale_swizzled": torch.randn( + self.output_dim, self.output_dim // 16 + ).to(dtype=FP8_DTYPE, device=self.device), + "wscale": torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device), + }, + ) + + def forward( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + ): + """Forward pass that creates the MLA attention + NVFP4 quant pattern.""" + attn_output = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(q.shape[0], self.output_dim), + ) + quant_output, output_block_scale = scaled_fp4_quant( + attn_output, 1 / self.w["scale"] + ) + return cutlass_scaled_fp4_mm( + a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + 
block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype, + ) + + +def is_nvfp4_supported(): + return current_platform.has_device_capability(100) + + +# MLA test configuration +MLA_DIMS: list[tuple[int, int, int, int, int]] = [] +PATTERN_TEST_MODELS_MLA_FP8: list[tuple[str, type]] = [] +PATTERN_TEST_MODELS_MLA_FP4: list[tuple[str, type]] = [] +BACKENDS_MLA_FP8: list[AttentionBackendEnum] = [] +BACKENDS_MLA_FP4: list[AttentionBackendEnum] = [] + +if current_platform.is_cuda(): + # (num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank) + MLA_DIMS = [(16, 128, 64, 128, 512)] + PATTERN_TEST_MODELS_MLA_FP8 = [ + ( + "deepseek-ai/DeepSeek-V2-Lite", + TestMLAAttentionFp8StaticQuantPatternModel, + ) + ] + PATTERN_TEST_MODELS_MLA_FP4 = [ + ( + "deepseek-ai/DeepSeek-V2-Lite", + TestMLAAttentionNvfp4QuantPatternModel, + ) + ] + BACKENDS_MLA_FP8 = [AttentionBackendEnum.TRITON_MLA] + BACKENDS_MLA_FP4 = [AttentionBackendEnum.TRITON_MLA] + + +@pytest.mark.parametrize( + "num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank", + MLA_DIMS, +) +@pytest.mark.parametrize("batch_size", [7, 256] if current_platform.is_cuda() else [8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize( + "backend, model_name, model_class, custom_ops", + list( + flat_product( + BACKENDS_MLA_FP8, + PATTERN_TEST_MODELS_MLA_FP8, + ["+quant_fp8", "-quant_fp8"], + ) + ) + + list(flat_product(BACKENDS_MLA_FP4, PATTERN_TEST_MODELS_MLA_FP4, [""])), +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" +) +@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") +def test_mla_attention_quant_pattern( + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + kv_lora_rank: int, + batch_size: int, + dtype: torch.dtype, + custom_ops: str, + model_name: str, + model_class: 
type[MLAAttentionQuantPatternModel], + backend: AttentionBackendEnum, + dist_init, + monkeypatch, + use_fresh_inductor_cache, +): + """Test MLA AttentionQuantPattern fusion pass""" + if ( + model_class is TestMLAAttentionNvfp4QuantPatternModel + and not is_nvfp4_supported() + ): + pytest.skip("NVFP4 is not supported on this GPU (requires SM 100+).") + + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.manual_seed(42) + + model_config = ModelConfig( + model=model_name, + max_model_len=2048, + dtype=dtype, + ) + vllm_config = VllmConfig( + model_config=model_config, + scheduler_config=SchedulerConfig( + max_num_seqs=1024, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops_list, + ), + cache_config=CacheConfig(cache_dtype="auto"), + attention_config=AttentionConfig(backend=backend), + ) + + # MLA inputs: q(B, N, qk_head_dim), kv_c_normed(B, L), k_pe(B, 1, R) + qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + q = torch.randn(batch_size, num_heads, qk_head_dim, dtype=dtype, device=device) + kv_c_normed = torch.randn(batch_size, kv_lora_rank, dtype=dtype, device=device) + k_pe = torch.randn(batch_size, 1, qk_rope_head_dim, dtype=dtype, device=device) + + # Mark first dimension as dynamic + torch._dynamo.mark_dynamic(q, 0) + torch._dynamo.mark_dynamic(kv_c_normed, 0) + torch._dynamo.mark_dynamic(k_pe, 0) + + # Run model without fusion + vllm_config_unfused = copy.deepcopy(vllm_config) + with ( + set_current_vllm_config(vllm_config_unfused), + set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused), + ): + model_unfused = model_class( + num_heads=num_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + 
kv_lora_rank=kv_lora_rank, + kv_cache_dtype=dtype, + device=device, + vllm_config=vllm_config_unfused, + ) + model_unfused = model_unfused.to(device) + # HACK: See https://github.com/vllm-project/vllm/issues/31044 + result_unfused_0 = model_unfused(q, kv_c_normed, k_pe) # noqa: F841 + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) + + compiled_unfused = torch.compile(model_unfused, fullgraph=True) + result_unfused = compiled_unfused(q, kv_c_normed, k_pe) + + # Run model with attn fusion enabled + vllm_config.compilation_config.pass_config = PassConfig( + fuse_attn_quant=True, eliminate_noops=True + ) + with ( + set_current_vllm_config(vllm_config), + set_forward_context(attn_metadata=None, vllm_config=vllm_config), + ): + model_fused = model_class( + num_heads=num_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_lora_rank=kv_lora_rank, + kv_cache_dtype=dtype, + device=device, + vllm_config=vllm_config, + w=model_unfused.w, + kv_b_proj_weight=model_unfused.kv_b_proj_weight, + ) + model_fused = model_fused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + + # Create test backend with fusion passes + noop_pass = NoOpEliminationPass(vllm_config) + attn_pass = LazyInitPass(MLAAttnQuantFusionPass, vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) + # HACK: See https://github.com/vllm-project/vllm/issues/31044 + result_fused_0 = model_fused(q, kv_c_normed, k_pe) # noqa: F841 + + compiled_fused = torch.compile( + model_fused, backend=test_backend, fullgraph=True + ) + + result_fused = compiled_fused(q, kv_c_normed, k_pe) + + # Check attn fusion support + quant_key: QuantKey = model_class.quant_key + attn_fusion_supported = [ + layer.impl.fused_output_quant_supported(quant_key) + for key, layer in
vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, MLAAttention) + ] + assert sum(attn_fusion_supported) == len(attn_fusion_supported), ( + "All MLA layers should support attention fusion" + ) + + # Check quantization ops in the graph + quant_op = ( + torch.ops.aten.reciprocal + if "-quant_fp8" in custom_ops_list + else QUANT_OPS[quant_key] + ) + test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic) + + assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) + + # Check MLA attention ops in the graph + attn_nodes_pre = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_post_pass)) + + assert len(attn_nodes_pre) > 0, "Should have MLA attention nodes before fusion" + assert len(attn_nodes_pre) == len(attn_nodes_post), ( + "Should have same number of MLA attention nodes before and after fusion" + ) + assert attn_nodes_pre[0].kwargs.get("output_scale") is None, ( + "MLA attention should not have output_scale before fusion" + ) + assert attn_nodes_post[0].kwargs.get("output_scale") is not None, ( + "MLA attention should have output_scale after fusion" + ) + + assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, ( + "MLA attention should not have output_block_scale before fusion" + ) + + if quant_key.dtype == FP8_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, ( + "MLA attention should not have output_block_scale after FP8 fusion" + ) + elif quant_key.dtype == FP4_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, ( + "MLA attention should have output_block_scale after FP4 fusion" + ) + + # Check numerical correctness + torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2) diff --git a/vllm/compilation/passes/fusion/mla_attn_quant_fusion.py b/vllm/compilation/passes/fusion/mla_attn_quant_fusion.py new file mode 100644 
index 000000000..5a9ef46a0 --- /dev/null +++ b/vllm/compilation/passes/fusion/mla_attn_quant_fusion.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm._custom_ops import create_fp4_output_tensors +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.layers.attention.mla_attention import MLAAttention +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, + kNvfp4Dynamic, +) +from vllm.platforms import current_platform + +from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement +from .matcher_utils import MatcherQuantFP8 +from .rms_quant_fusion import QUANT_OPS + +logger = init_logger(__name__) + +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + +MLA_ATTN_OP = torch.ops.vllm.unified_mla_attention_with_output.default + + +class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): + """ + Fusion for MLA Attention+Fp8StaticQuant. + + Matches the pattern: MLA attention -> static FP8 quant, and replaces + it with MLA attention(output_scale=scale, output=fp8_buffer). 
+ """ + + def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None: + self._layer_name = layer.layer_name + self._num_heads = layer.num_heads + self._v_head_dim = layer.v_head_dim + self._kv_lora_rank = layer.kv_lora_rank + self._qk_rope_head_dim = layer.qk_rope_head_dim + self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim + self._output_dim = layer.num_heads * layer.v_head_dim + self._dtype = dtype + self._quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) + + @property + def pattern(self) -> Callable[..., torch.Tensor]: + def _pattern( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output_attn: torch.Tensor, + scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, + ) -> torch.Tensor: + at1 = auto_functionalized( + MLA_ATTN_OP, + q=q, + kv_c_normed=kv_c_normed, + k_pe=k_pe, + output=output_attn, + layer_name=self._layer_name, + output_scale=None, + output_block_scale=None, + kv_cache_dummy_dep=kv_cache_dummy_dep, + ) + # MLA output is already 2D (T, N*V), no reshape needed + return self._quant_matcher(at1[1], scale)[0] + + return _pattern + + @property + def replacement(self) -> Callable[..., torch.Tensor]: + def _replacement( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output_attn: torch.Tensor, + scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, + ) -> torch.Tensor: + # MLA output in quant_dtype + output_attn = torch.empty( + [q.shape[0], self._output_dim], + dtype=FP8_DTYPE, + device=q.device, + ) + at1 = auto_functionalized( + MLA_ATTN_OP, + q=q, + kv_c_normed=kv_c_normed, + k_pe=k_pe, + output=output_attn, + layer_name=self._layer_name, + output_scale=scale, + output_block_scale=None, + kv_cache_dummy_dep=kv_cache_dummy_dep, + ) + return at1[1] + + return _replacement + + def get_inputs(self) -> list[torch.Tensor]: + return [ + self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype), + self.empty(5, self._kv_lora_rank, dtype=self._dtype), + self.empty(5, 
1, self._qk_rope_head_dim, dtype=self._dtype), + self.empty(5, self._output_dim, dtype=self._dtype), + self.empty_fp32(1, 1), + self.empty(0, dtype=self._dtype), + ] + + +class MLAAttnNvfp4QuantPattern( + VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]] +): + """ + Fusion for MLA Attention+Nvfp4Quant. + + Matches the pattern: MLA attention -> NVFP4 quant, and replaces + it with MLA attention(output_scale=scale, output_block_scale=block_scale, + output=fp4_buffer). + """ + + def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None: + self._layer_name = layer.layer_name + self._num_heads = layer.num_heads + self._v_head_dim = layer.v_head_dim + self._kv_lora_rank = layer.kv_lora_rank + self._qk_rope_head_dim = layer.qk_rope_head_dim + self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim + self._output_dim = layer.num_heads * layer.v_head_dim + self._dtype = dtype + self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic] + + @property + def pattern( + self, + ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: + def _pattern( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output_attn: torch.Tensor, + input_scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + at1 = auto_functionalized( + MLA_ATTN_OP, + q=q, + kv_c_normed=kv_c_normed, + k_pe=k_pe, + output=output_attn, + layer_name=self._layer_name, + output_scale=None, + output_block_scale=None, + kv_cache_dummy_dep=kv_cache_dummy_dep, + ) + # Replicate what scaled_fp4_quant() does: allocate output + # tensors inline then call the .out variant. 
+ output_quant, output_scale = create_fp4_output_tensors( + at1[1].shape[0], at1[1].shape[1], at1[1].device, True + ) + at2 = auto_functionalized( + self._QUANT_OP, + input=at1[1], + input_scale=input_scale, + is_sf_swizzled_layout=True, + output=output_quant, + output_scale=output_scale, + ) + output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) + return at2[1], output_scale_view + + return _pattern + + @property + def replacement( + self, + ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: + def _replacement( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output_attn: torch.Tensor, + input_scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # MLA output in quant_dtype (FP4 packed as uint8) + output_attn = torch.empty( + [q.shape[0], self._output_dim // 2], + dtype=FP4_DTYPE, + device=q.device, + ) + # attention output block scale + output_scale = create_fp4_output_tensors( + q.shape[0], self._output_dim, q.device, True + )[1] + output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) + at2 = auto_functionalized( + MLA_ATTN_OP, + q=q, + kv_c_normed=kv_c_normed, + k_pe=k_pe, + output=output_attn, + layer_name=self._layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view, + kv_cache_dummy_dep=kv_cache_dummy_dep, + ) + return at2[1], at2[2] + + return _replacement + + def get_inputs(self) -> list[torch.Tensor]: + return [ + self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype), + self.empty(5, self._kv_lora_rank, dtype=self._dtype), + self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype), + self.empty(5, self._output_dim, dtype=self._dtype), + self.empty_fp32(1, 1), + self.empty(0, dtype=self._dtype), + ] + + +class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass): + """ + This pass fuses post-attention quantization onto MLA attention if supported. 
+ + It uses the pattern matcher and matches each MLA layer manually, as strings + cannot be wildcarded. This also lets us check support on attention layers + upon registration instead of during pattern matching. + """ + + def __init__(self, config: VllmConfig) -> None: + super().__init__(config, "mla_attn_quant_fusion") + + dtype = config.model_config.dtype + layers = list(get_layers_from_vllm_config(config, MLAAttention).values()) + + if len(layers) == 0: + logger.warning( + "MLA attention + quant fusion is enabled, but no MLA " + "attention layers were found in " + "CompilationConfig.static_forward_context " + "so no fusion patterns were registered." + ) + + for layer in layers: + if layer.impl.fused_output_quant_supported(kFp8StaticTensorSym): + self.register(MLAAttnFp8StaticQuantPattern(layer, dtype)) + + if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + for layer in layers: + if layer.impl.fused_output_quant_supported(kNvfp4Dynamic): + self.register(MLAAttnNvfp4QuantPattern(layer, dtype)) + + self.dump_patterns(config, self.pm_pass) diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 057174141..b4823a0af 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -27,6 +27,7 @@ if rocm_aiter_ops.is_enabled(): if current_platform.is_cuda_alike(): from .fusion.act_quant_fusion import ActivationQuantFusionPass from .fusion.attn_quant_fusion import AttnQuantFusionPass + from .fusion.mla_attn_quant_fusion import MLAAttnQuantFusionPass from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass from .fusion.rms_quant_fusion import RMSNormQuantFusionPass from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass @@ -157,6 +158,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] if self.pass_config.fuse_attn_quant: self.passes += [AttnQuantFusionPass(config)] + self.passes += [MLAAttnQuantFusionPass(config)] if 
self.pass_config.enable_qk_norm_rope_fusion: self.passes += [SplitCoalescingPass(config)] diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 6a089fdfa..716c208a9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -121,7 +121,7 @@ class PassConfig: fuse_act_quant: bool = None # type: ignore[assignment] """Fuse the custom SiluMul + quant ops.""" fuse_attn_quant: bool = None # type: ignore[assignment] - """Fuse the custom attention + quant ops.""" + """Fuse the custom Attention and MLAAttention + quant ops.""" eliminate_noops: bool = Field(default=True) """Eliminate no-op ops.""" enable_sp: bool = None # type: ignore[assignment] diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 0be46fbbc..699238b48 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -449,6 +449,11 @@ class MLAAttention(nn.Module, AttentionLayerBase): group_shape=GroupShape.PER_TENSOR, compile_native=True, ) + self._quant_fp8_op = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + compile_native=True, + ) @property def chunked_prefill_workspace_size(self) -> int: @@ -545,9 +550,19 @@ class MLAAttention(nn.Module, AttentionLayerBase): ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported for MLA" + use_quant = output_scale is not None or output_block_scale is not None + if use_quant: + # The fusion pass has allocated output with quantized dtype + # (FP8 or uint8 for FP4). We can't write into it directly, + # so we swap in a temp buffer for computation, then quantize + # into the real output at the end. 
+ # NOTE(carlyou): this is temporary until kernels support fp8 output + quant_output = output + output = torch.empty( + output.shape[0], + self.num_heads * self.v_head_dim, + dtype=q.dtype, + device=output.device, ) if attn_metadata is None: @@ -567,6 +582,8 @@ class MLAAttention(nn.Module, AttentionLayerBase): # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. + if use_quant: + return quant_output.fill_(0) return output.fill_(0) if self.impl.dcp_world_size == -1: @@ -706,6 +723,21 @@ class MLAAttention(nn.Module, AttentionLayerBase): # v_up projection self._v_up_proj(attn_out, out=mqa_output_slice) + + if use_quant: + # Quantize the BF16 computation result into the quantized output + actual = output[:num_actual_toks] + if output_block_scale is not None: + # NVFP4: two FP4 values packed into one uint8 + fp4_data, fp4_scales = ops.scaled_fp4_quant(actual, output_scale) + quant_output[:num_actual_toks].copy_(fp4_data) + output_block_scale.copy_(fp4_scales) + else: + # Static FP8 quantization + fp8_data, _ = self._quant_fp8_op(actual, output_scale) + quant_output[:num_actual_toks].copy_(fp8_data) + return quant_output + return output_padded def process_weights_after_loading(self, act_dtype: torch.dtype): @@ -2069,6 +2101,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): understand this class """ + def fused_output_quant_supported(self, quant_key): + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, + kNvfp4Dynamic, + ) + + return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) + def __init__( self, num_heads: int, @@ -2513,8 +2553,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if hasattr(self.kv_b_proj, "weight") else self.kv_b_proj.params_dtype ) - if use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype(): - kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype) + # For NVFP4, weights are packed uint8 — keep input 
in model dtype + # since the NVFP4 linear layer quantizes internally. + if ( + use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype() + ) and _kv_b_proj_w_dtype != torch.uint8: + kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype) k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) kv_nope = self.kv_b_proj(kv_c_normed)[0].view( diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 4663cb71d..bb05b31bb 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -10,6 +10,11 @@ import numpy as np import torch from typing_extensions import deprecated +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, + kNvfp4Dynamic, +) + if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.config.cache import CacheDType @@ -873,6 +878,14 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]): """MQA-style decode forward pass.""" raise NotImplementedError + def fused_output_quant_supported(self, quant_key: "QuantKey"): + """ + Does this attention implementation support fused output quantization. + Since MLA quantization is done manually in forward_impl (common code), + all MLA backends support it by default. + """ + return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) + def do_kv_cache_update( self, kv_c_normed: torch.Tensor, @@ -903,6 +916,14 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]): They do not support prefill (MHA-style) attention. """ + def fused_output_quant_supported(self, quant_key: "QuantKey"): + """ + Does this attention implementation support fused output quantization. + Since MLA quantization is done manually in forward_impl (common code), + all MLA backends support it by default. + """ + return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) + @abstractmethod def __init__( self,