[mla] Support fused FP8/NVFP4 output quantization in MLA attention (#35792) (#36205)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Carl Y
2026-04-02 18:16:11 -07:00
committed by GitHub
parent ee3cf45739
commit 1f5ec2889c
12 changed files with 928 additions and 17 deletions


@@ -72,6 +72,7 @@ steps:
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
- tests/compile/passes/test_fusion_attn.py
- tests/compile/passes/test_mla_attn_quant_fusion.py
- tests/compile/passes/test_silu_mul_quant_fusion.py
- tests/compile/passes/distributed/test_fusion_all_reduce.py
- tests/compile/fullgraph/test_full_graph.py
@@ -79,6 +80,7 @@ steps:
# B200 runners are limited, so we restrict this job to the minimal set of tests that are only supported on Blackwell
- nvidia-smi
- pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
- pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_devices=2 is not set
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py


@@ -22,6 +22,7 @@ or just on the low or high end.
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always |
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low |
| [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High |
@@ -40,6 +41,7 @@ The table below lists the quantization schemes supported by each fusion on each
| ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
| `fuse_attn_quant` (MLA)\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static (untested)\* |
| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
| `enable_qk_norm_rope_fusion` | FP16/BF16 | FP16/BF16 | FP16/BF16† | FP16/BF16† | — |
| `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
@@ -129,7 +131,8 @@ on SM90/SM100) and configurable via `PassConfig.fi_allreduce_fusion_max_size_mb`
explicitly. It requires the full model graph to be visible (Inductor partition or `splitting_ops=[]`).
**What it fuses.** Fuses the attention output quantization directly after the attention computation,
-eliminating a full-precision memory round-trip of the attention output. Patterns covered:
+eliminating a full-precision memory round-trip of the attention output. This fusion supports both
+standard `Attention` and `MLAAttention` (used by DeepSeek-V2/V3/R1 models). Patterns covered:
`Attention → FP8 static quant`:
@@ -142,11 +145,24 @@ eliminating a full-precision memory round-trip of the attention output. Patterns
- `FLASHINFER`: CUDA sm100+ with FlashInfer installed
`MLAAttention → FP8 static quant` / `MLAAttention → NVFP4 dynamic quant`:
The MLA fusion operates at the graph level on the `unified_mla_attention_with_output` op and works
with all MLA decode and prefill backend combinations. Unlike standard `Attention` backends (where
the kernel writes FP8 output directly), no MLA prefill or decode backend currently supports direct
FP8/FP4 output. The fusion writes to an intermediate buffer and quantizes in a separate step, so
there is no memory round-trip elimination yet.
!!! info
    The MLA attention fusion is not expected to yield a measurable speedup yet.
    This will improve once MLA prefill/decode kernels support direct FP8/FP4 output.
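As a rough illustration (plain Python, not vLLM's kernels), static FP8 quantization, the step this fusion folds into the attention op's epilogue, amounts to dividing by a precomputed scale and clamping to the FP8 E4M3 dynamic range. The sketch below omits rounding to the representable E4M3 grid:

```python
# Illustrative sketch of static FP8 (E4M3) tensor quantization, the step
# the fusion folds into the attention epilogue. Not vLLM code: real FP8
# quant also rounds to the representable E4M3 grid, omitted here.
FP8_E4M3_MAX = 448.0  # largest finite E4M3 value


def quant_fp8_static(values: list[float], scale: float) -> list[float]:
    """Scale values by a precomputed per-tensor scale and clamp to E4M3 range."""
    out = []
    for v in values:
        q = v / scale
        q = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, q))
        out.append(q)
    return out


def dequant_fp8_static(quantized: list[float], scale: float) -> list[float]:
    """Invert the static quantization (up to clamping loss)."""
    return [q * scale for q in quantized]
```

Unfused, the attention output takes a round-trip through memory in BF16 before a separate quant kernel re-reads it; fusing the quantization into the attention op is what would remove that round-trip once MLA kernels can write FP8/FP4 directly.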
Other attention backends do not support fused output quantization yet.
**Code locations.**
-- Pass: [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
+- Pass (Attention): [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
+- Pass (MLAAttention): [`vllm/compilation/passes/fusion/mla_attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/mla_attn_quant_fusion.py)
- Attention backends: [`vllm/v1/attention/backends/`](https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/)
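Mirroring the test setup in this commit, enabling the fusion at the pass level might look like the following sketch (the exact config surface may differ by version; `eliminate_noops` is enabled alongside it in the tests, presumably so the patterns match cleanly):

```python
# Sketch based on the PassConfig usage in this commit's tests: enable the
# attention output quant fusion, which covers both standard Attention and
# MLAAttention patterns when the corresponding passes are registered.
from vllm.config import CompilationConfig, PassConfig

compilation_config = CompilationConfig(
    pass_config=PassConfig(
        fuse_attn_quant=True,  # registers (MLA)AttnQuantFusionPass patterns
        eliminate_noops=True,  # noop elimination helps the patterns match
    ),
)
```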
### RoPE + KV-Cache Update (`fuse_rope_kvcache`)


@@ -84,10 +84,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
rocm_aiter_ops.refresh_env_variables()
# Filter here to reduce code duplication
backend_name = attn_backend.backend.name.lower()
requires_mla = "deepseek" in model_name.lower()
-is_mla = "mla" in attn_backend.backend.name.lower()
+is_mla = "mla" in backend_name
# DeepSeek V3.2 uses sparse MLA
requires_sparse = "v3.2" in model_name.lower()
is_sparse = "sparse" in backend_name
-if requires_mla != is_mla:
+if requires_mla != is_mla or requires_sparse != is_sparse:
pytest.skip(
f"Incompatible model '{model_name}' and "
f"attention backend '{attn_backend.backend.name}'"
@@ -231,7 +235,9 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
)
elif match_name == "attn_quant_fusion":
-actual_match = match_table.get(match_name, 0)
+actual_match = match_table.get(
+    "attn_quant_fusion", 0
+) + match_table.get("mla_attn_quant_fusion", 0)
assert actual_match == expected_matches * n_expected, (
f"Could not find {expected_matches * n_expected} "
f"{match_name} (found {actual_match})."


@@ -58,6 +58,15 @@ TRITON_MLA_ATTN = pytest.param(
id="TRITON_MLA",
)
FLASHMLA_SPARSE_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.FLASHMLA_SPARSE),
id="FLASHMLA_SPARSE",
marks=pytest.mark.skipif(
not is_blackwell(),
reason="FlashMLA Sparse requires Blackwell",
),
)
# Models
llama3_8b = ModelFusionInfo(
model_name="meta-llama/Llama-3.1-8B-Instruct",
@@ -141,6 +150,18 @@ qwen3_a3b_fp8 = ModelFusionInfo(
),
)
deepseek_coder_v2_lite_fp8 = ModelFusionInfo(
model_name="RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8",
matches=lambda n_layers: Matches(
# first_k_dense_replace=1; MoE hides most rms+quant sites
rms_quant_fusion=1,
act_quant_fusion=min(1, n_layers), # dense layers only
# MLA attn + static FP8 quant
attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2 + 1,
),
)
deepseek_v3_fp8 = ModelFusionInfo(
model_name="deepseek-ai/DeepSeek-V3",
matches=lambda n_layers: Matches(
@@ -152,7 +173,7 @@ deepseek_v3_fp8 = ModelFusionInfo(
rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers
# silu+block quant
act_quant_fusion=min(3, n_layers), # dense layers only
-# MLA attn + quant not supported yet:
+# MLA attn + per-group FP8 quant not supported yet:
# https://github.com/vllm-project/vllm/issues/35792
attn_quant_fusion=0,
ar_rms_fusion=n_layers * 2 + 1,
@@ -162,6 +183,16 @@ deepseek_v3_fp8 = ModelFusionInfo(
),
)
deepseek_v32_fp4 = ModelFusionInfo(
model_name="nvidia/DeepSeek-V3.2-NVFP4",
matches=lambda n_layers: Matches(
rms_quant_fusion=0,
act_quant_fusion=0,
attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2 + 1,
),
)
gpt_oss_20b = ModelFusionInfo(
model_name="openai/gpt-oss-20b",
matches=lambda n_layers: Matches(


@@ -18,11 +18,14 @@ from .common import (
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
FLASHMLA_SPARSE_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
TRITON_MLA_ATTN,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
deepseek_v32_fp4,
llama3_8b_fp4,
llama3_8b_fp8,
llama4_scout_fp4,
@@ -37,6 +40,7 @@ from .models import (
(*llama3_8b_fp8, False),
(*qwen3_a3b_fp8, False),
(*qwen3_a3b_fp8, True),
(*deepseek_coder_v2_lite_fp8, False),
(*deepseek_v3_fp8, False),
(*deepseek_v3_fp8, True),
pytest.param(
@@ -144,9 +148,12 @@ def test_tp1_fp8_fusions(
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
-[llama3_8b_fp4, llama4_scout_fp4],
+[llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
)
-@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
+)
@pytest.mark.parametrize("n_layers", [6])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)


@@ -18,8 +18,11 @@ from .common import (
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
FLASHMLA_SPARSE_ATTN,
TRITON_ATTN,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
deepseek_v32_fp4,
gpt_oss_20b,
llama3_8b,
llama3_8b_fp4,
@@ -37,7 +40,13 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
# qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
-[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
+[
+    llama3_8b_fp8,
+    llama4_scout_fp8,
+    qwen3_a3b_fp8,
+    deepseek_coder_v2_lite_fp8,
+    deepseek_v3_fp8,
+],
)
@pytest.mark.parametrize(
"attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
@@ -104,9 +113,12 @@ def test_tp2_ar_rms_fp8_fusions(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
-[llama3_8b_fp4, llama4_scout_fp4],
+[llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
)
-@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
+)
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)


@@ -0,0 +1,508 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest
import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import TestFP8Layer, flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.mla_attn_quant_fusion import (
MLA_ATTN_OP,
MLAAttnQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import (
AttentionConfig,
CacheConfig,
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import MLAAttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
class MLAAttentionQuantPatternModel(torch.nn.Module):
"""Base model for MLA AttentionQuantPattern fusion."""
def __init__(
self,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
kv_lora_rank: int,
kv_cache_dtype: torch.dtype,
device: torch.device,
vllm_config: VllmConfig,
**kwargs,
):
super().__init__()
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.kv_lora_rank = kv_lora_rank
self.output_dim = num_heads * v_head_dim
self.head_size = kv_lora_rank + qk_rope_head_dim
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
# Create kv_b_proj (ColumnParallelLinear) on device.
# Reuse weights from prior model instance when available, because
# ColumnParallelLinear may get NaN from recycled CUDA memory after
# torch.compile runs in the same process.
kv_b_proj = ColumnParallelLinear(
input_size=kv_lora_rank,
output_size=num_heads * (qk_nope_head_dim + v_head_dim),
bias=False,
prefix="model.layers.0.self_attn.kv_b_proj",
).to(device)
kv_b_proj_weight = kwargs.get("kv_b_proj_weight")
if kv_b_proj_weight is not None:
kv_b_proj.weight.data.copy_(kv_b_proj_weight)
elif kv_b_proj.weight.data.isnan().any():
# Sanitize NaN from recycled CUDA memory
kv_b_proj.weight.data.normal_()
# Create MLAAttention
self.mla_attn = MLAAttention(
num_heads=num_heads,
scale=1.0 / (self.qk_head_dim**0.5),
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
kv_b_proj=kv_b_proj,
cache_config=vllm_config.cache_config,
quant_config=self.quant_config,
prefix="model.layers.0.self_attn.attn",
)
self.mla_attn._k_scale = self.mla_attn._k_scale.to(device)
self.mla_attn._v_scale = self.mla_attn._v_scale.to(device)
# Initialize W_UK_T and W_UV from kv_b_proj weights
self.mla_attn.process_weights_after_loading(torch.get_default_dtype())
self.kv_b_proj_weight = kv_b_proj.weight.data.clone()
self.block_size = 16
# Initialize MLA MetadataBuilder
self.builder = self.mla_attn.attn_backend.get_builder_cls()(
kv_cache_spec=MLAAttentionSpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
),
layer_names=[self.mla_attn.layer_name],
vllm_config=self.vllm_config,
device=self.device,
)
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize MLA attention metadata.
NOTE: Uses decode-only batch (query_len=1 per request). The prefill
(forward_mha) path is not separately tested here because it requires
FlashAttention availability and different input tensor shapes. The
quant logic in forward_impl is identical for both paths — it quantizes
the full output[:num_actual_toks] buffer after both forward_mha and
forward_mqa have written their results.
"""
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec, self.block_size, self.device, arange_block_indices=True
)
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
# MLA KV cache is 3D: (num_blocks, block_size, head_size)
attn_backend = self.mla_attn.attn_backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, 1, self.head_size
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
raw_tensor = torch.zeros(
ordered_shape, dtype=self.kv_cache_dtype, device=self.device
)
kv_cache = raw_tensor.permute(*inv_order)
self.mla_attn.kv_cache = kv_cache
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
return self.attn_metadata
class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
"""Test model for MLA Attention + FP8 static quant fusion."""
quant_key = kFp8StaticTensorSym
quant_config = Fp8Config()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp8_linear = TestFP8Layer(
weight_shape=(self.output_dim, self.output_dim),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
)
w = kwargs.get("w")
if w is not None:
self.fp8_linear.weight = w["weight"]
self.fp8_linear.weight_scale = w["wscale"]
self.fp8_linear.input_scale = w["scale"]
self.w = {
"weight": self.fp8_linear.weight,
"wscale": self.fp8_linear.weight_scale,
"scale": self.fp8_linear.input_scale,
}
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
):
"""Forward pass that creates the MLA attention + FP8 quant pattern."""
attn_output = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(q.shape[0], self.output_dim),
)
return self.fp8_linear(attn_output)
class TestMLAAttentionNvfp4QuantPatternModel(MLAAttentionQuantPatternModel):
"""Test model for MLA Attention + NVFP4 quant fusion."""
quant_key = kNvfp4Dynamic
quant_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=False,
kv_cache_quant_algo=None,
exclude_modules=[],
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.w = kwargs.get(
"w",
{
"weight": torch.randint(
256,
(self.output_dim, self.output_dim // 2),
dtype=FP4_DTYPE,
device=self.device,
),
"wscale_swizzled": torch.randn(
self.output_dim, self.output_dim // 16
).to(dtype=FP8_DTYPE, device=self.device),
"wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
"scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
},
)
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
):
"""Forward pass that creates the MLA attention + NVFP4 quant pattern."""
attn_output = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(q.shape[0], self.output_dim),
)
quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"]
)
return cutlass_scaled_fp4_mm(
a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype,
)
def is_nvfp4_supported():
return current_platform.has_device_capability(100)
# MLA test configuration
MLA_DIMS: list[tuple[int, int, int, int, int]] = []
PATTERN_TEST_MODELS_MLA_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_MLA_FP4: list[tuple[str, type]] = []
BACKENDS_MLA_FP8: list[AttentionBackendEnum] = []
BACKENDS_MLA_FP4: list[AttentionBackendEnum] = []
if current_platform.is_cuda():
# (num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank)
MLA_DIMS = [(16, 128, 64, 128, 512)]
PATTERN_TEST_MODELS_MLA_FP8 = [
(
"deepseek-ai/DeepSeek-V2-Lite",
TestMLAAttentionFp8StaticQuantPatternModel,
)
]
PATTERN_TEST_MODELS_MLA_FP4 = [
(
"deepseek-ai/DeepSeek-V2-Lite",
TestMLAAttentionNvfp4QuantPatternModel,
)
]
BACKENDS_MLA_FP8 = [AttentionBackendEnum.TRITON_MLA]
BACKENDS_MLA_FP4 = [AttentionBackendEnum.TRITON_MLA]
@pytest.mark.parametrize(
"num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank",
MLA_DIMS,
)
@pytest.mark.parametrize("batch_size", [7, 256] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"backend, model_name, model_class, custom_ops",
list(
flat_product(
BACKENDS_MLA_FP8,
PATTERN_TEST_MODELS_MLA_FP8,
["+quant_fp8", "-quant_fp8"],
)
)
+ list(flat_product(BACKENDS_MLA_FP4, PATTERN_TEST_MODELS_MLA_FP4, [""])),
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
def test_mla_attention_quant_pattern(
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
kv_lora_rank: int,
batch_size: int,
dtype: torch.dtype,
custom_ops: str,
model_name: str,
model_class: type[MLAAttentionQuantPatternModel],
backend: AttentionBackendEnum,
dist_init,
monkeypatch,
use_fresh_inductor_cache,
):
"""Test MLA AttentionQuantPattern fusion pass"""
if (
model_class is TestMLAAttentionNvfp4QuantPatternModel
and not is_nvfp4_supported()
):
pytest.skip("NVFP4 is not supported on this GPU (requires SM 100+).")
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.manual_seed(42)
model_config = ModelConfig(
model=model_name,
max_model_len=2048,
dtype=dtype,
)
vllm_config = VllmConfig(
model_config=model_config,
scheduler_config=SchedulerConfig(
max_num_seqs=1024,
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops_list,
),
cache_config=CacheConfig(cache_dtype="auto"),
attention_config=AttentionConfig(backend=backend),
)
# MLA inputs: q(B, N, qk_head_dim), kv_c_normed(B, L), k_pe(B, 1, R)
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
q = torch.randn(batch_size, num_heads, qk_head_dim, dtype=dtype, device=device)
kv_c_normed = torch.randn(batch_size, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(batch_size, 1, qk_rope_head_dim, dtype=dtype, device=device)
# Mark first dimension as dynamic
torch._dynamo.mark_dynamic(q, 0)
torch._dynamo.mark_dynamic(kv_c_normed, 0)
torch._dynamo.mark_dynamic(k_pe, 0)
# Run model without fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with (
set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
):
model_unfused = model_class(
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
kv_cache_dtype=dtype,
device=device,
vllm_config=vllm_config_unfused,
)
model_unfused = model_unfused.to(device)
# HACK: See https://github.com/vllm-project/vllm/issues/31044
result_unfused_0 = model_unfused(q, kv_c_normed, k_pe) # noqa: F841
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
compiled_unfused = torch.compile(model_unfused, fullgraph=True)
result_unfused = compiled_unfused(q, kv_c_normed, k_pe)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
fuse_attn_quant=True, eliminate_noops=True
)
with (
set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
):
model_fused = model_class(
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
kv_cache_dtype=dtype,
device=device,
vllm_config=vllm_config,
w=model_unfused.w,
kv_b_proj_weight=model_unfused.kv_b_proj_weight,
)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
# Create test backend with fusion passes
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = LazyInitPass(MLAAttnQuantFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# HACK: See https://github.com/vllm-project/vllm/issues/31044
result_fused_0 = model_fused(q, kv_c_normed, k_pe) # noqa: F841
compiled_fused = torch.compile(
model_fused, backend=test_backend, fullgraph=True
)
result_fused = compiled_fused(q, kv_c_normed, k_pe)
# Check attn fusion support
quant_key: QuantKey = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, MLAAttention)
]
assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
"All MLA layers should support attention fusion"
)
# Check quantization ops in the graph
quant_op = (
torch.ops.aten.reciprocal
if "-quant_fp8" in custom_ops_list
else QUANT_OPS[quant_key]
)
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check MLA attention ops in the graph
attn_nodes_pre = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_post_pass))
assert len(attn_nodes_pre) > 0, "Should have MLA attention nodes before fusion"
assert len(attn_nodes_pre) == len(attn_nodes_post), (
"Should have same number of MLA attention nodes before and after fusion"
)
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
"MLA attention should not have output_scale before fusion"
)
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
"MLA attention should have output_scale after fusion"
)
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
"MLA attention should not have output_block_scale before fusion"
)
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
"MLA attention should not have output_block_scale after FP8 fusion"
)
elif quant_key.dtype == FP4_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
"MLA attention should have output_block_scale after FP4 fusion"
)
# Check numerical correctness
torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)


@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm._custom_ops import create_fp4_output_tensors
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLAAttention
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from .matcher_utils import MatcherQuantFP8
from .rms_quant_fusion import QUANT_OPS
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
MLA_ATTN_OP = torch.ops.vllm.unified_mla_attention_with_output.default
class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
"""
Fusion for MLA Attention+Fp8StaticQuant.
Matches the pattern: MLA attention -> static FP8 quant, and replaces
it with MLA attention(output_scale=scale, output=fp8_buffer).
"""
def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None:
self._layer_name = layer.layer_name
self._num_heads = layer.num_heads
self._v_head_dim = layer.v_head_dim
self._kv_lora_rank = layer.kv_lora_rank
self._qk_rope_head_dim = layer.qk_rope_head_dim
self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim
self._output_dim = layer.num_heads * layer.v_head_dim
self._dtype = dtype
self._quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
@property
def pattern(self) -> Callable[..., torch.Tensor]:
def _pattern(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=self._layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
# MLA output is already 2D (T, N*V), no reshape needed
return self._quant_matcher(at1[1], scale)[0]
return _pattern
@property
def replacement(self) -> Callable[..., torch.Tensor]:
def _replacement(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
# MLA output in quant_dtype
output_attn = torch.empty(
[q.shape[0], self._output_dim],
dtype=FP8_DTYPE,
device=q.device,
)
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=self._layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return at1[1]
return _replacement
def get_inputs(self) -> list[torch.Tensor]:
return [
self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
self.empty(5, self._kv_lora_rank, dtype=self._dtype),
self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
self.empty(5, self._output_dim, dtype=self._dtype),
self.empty_fp32(1, 1),
self.empty(0, dtype=self._dtype),
]
class MLAAttnNvfp4QuantPattern(
VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]]
):
"""
Fusion for MLA Attention+Nvfp4Quant.
Matches the pattern: MLA attention -> NVFP4 quant, and replaces
it with MLA attention(output_scale=scale, output_block_scale=block_scale,
output=fp4_buffer).
"""
def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None:
self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._v_head_dim = layer.v_head_dim
        self._kv_lora_rank = layer.kv_lora_rank
        self._qk_rope_head_dim = layer.qk_rope_head_dim
        self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim
        self._output_dim = layer.num_heads * layer.v_head_dim
        self._dtype = dtype
        self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic]

    @property
    def pattern(
        self,
    ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        def _pattern(
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            output_attn: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at1 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # Replicate what scaled_fp4_quant() does: allocate output
            # tensors inline then call the .out variant.
            output_quant, output_scale = create_fp4_output_tensors(
                at1[1].shape[0], at1[1].shape[1], at1[1].device, True
            )
            at2 = auto_functionalized(
                self._QUANT_OP,
                input=at1[1],
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
                output=output_quant,
                output_scale=output_scale,
            )
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        return _pattern

    @property
    def replacement(
        self,
    ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        def _replacement(
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            output_attn: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # MLA output in quant_dtype (FP4 packed as uint8)
            output_attn = torch.empty(
                [q.shape[0], self._output_dim // 2],
                dtype=FP4_DTYPE,
                device=q.device,
            )
            # attention output block scale
            output_scale = create_fp4_output_tensors(
                q.shape[0], self._output_dim, q.device, True
            )[1]
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            at2 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return at2[1], at2[2]

        return _replacement

    def get_inputs(self) -> list[torch.Tensor]:
        return [
            self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
            self.empty(5, self._kv_lora_rank, dtype=self._dtype),
            self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
            self.empty(5, self._output_dim, dtype=self._dtype),
            self.empty_fp32(1, 1),
            self.empty(0, dtype=self._dtype),
        ]


class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto MLA attention if supported.
    It uses the pattern matcher and matches each MLA layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config, "mla_attn_quant_fusion")
        dtype = config.model_config.dtype
        layers = list(get_layers_from_vllm_config(config, MLAAttention).values())
        if len(layers) == 0:
            logger.warning(
                "MLA attention + quant fusion is enabled, but no MLA "
                "attention layers were found in "
                "CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )
        for layer in layers:
            if layer.impl.fused_output_quant_supported(kFp8StaticTensorSym):
                self.register(MLAAttnFp8StaticQuantPattern(layer, dtype))
        if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
            for layer in layers:
                if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
                    self.register(MLAAttnNvfp4QuantPattern(layer, dtype))
        self.dump_patterns(config, self.pm_pass)
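The pattern/replacement pair registered above encodes a simple contract: a graph region that runs MLA attention and then a standalone quant op is rewritten into a single attention call that emits quantized output directly. A toy sketch of that contract in plain Python (all functions here are illustrative stand-ins, not vLLM ops):

```python
# Toy model of the fusion contract: the "pattern" (attention followed by a
# separate quant step) must be numerically equivalent to the "replacement"
# (attention that quantizes its own output). Names are illustrative only.

def toy_attn(x: list[float]) -> list[float]:
    # stand-in for the attention op (simple transform)
    return [2.0 * v for v in x]

def toy_quant(x: list[float], scale: float) -> list[int]:
    # stand-in for static per-tensor quantization: scale, round, clamp
    return [max(-8, min(7, round(v / scale))) for v in x]

def unfused(x: list[float], scale: float) -> list[int]:
    # pattern: attention, then a separate quant op
    return toy_quant(toy_attn(x), scale)

def fused(x: list[float], scale: float) -> list[int]:
    # replacement: one call that produces quantized output directly
    return [max(-8, min(7, round(2.0 * v / scale))) for v in x]

assert unfused([0.5, -1.0, 3.0], 0.25) == fused([0.5, -1.0, 3.0], 0.25) == [4, -8, 7]
```

The rewrite is only legal because both sides are numerically equivalent; that equivalence is what the pattern matcher relies on.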

View File

@@ -27,6 +27,7 @@ if rocm_aiter_ops.is_enabled():
 if current_platform.is_cuda_alike():
     from .fusion.act_quant_fusion import ActivationQuantFusionPass
     from .fusion.attn_quant_fusion import AttnQuantFusionPass
+    from .fusion.mla_attn_quant_fusion import MLAAttnQuantFusionPass
     from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
     from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
     from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
@@ -157,6 +158,7 @@ class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
         if self.pass_config.fuse_attn_quant:
             self.passes += [AttnQuantFusionPass(config)]
+            self.passes += [MLAAttnQuantFusionPass(config)]
         if self.pass_config.enable_qk_norm_rope_fusion:
             self.passes += [SplitCoalescingPass(config)]

View File

@@ -121,7 +121,7 @@ class PassConfig:
     fuse_act_quant: bool = None  # type: ignore[assignment]
     """Fuse the custom SiluMul + quant ops."""
     fuse_attn_quant: bool = None  # type: ignore[assignment]
-    """Fuse the custom attention + quant ops."""
+    """Fuse the custom Attention and MLAAttention + quant ops."""
     eliminate_noops: bool = Field(default=True)
     """Eliminate no-op ops."""
     enable_sp: bool = None  # type: ignore[assignment]

View File

@@ -449,6 +449,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             group_shape=GroupShape.PER_TENSOR,
             compile_native=True,
         )
+        self._quant_fp8_op = QuantFP8(
+            static=True,
+            group_shape=GroupShape.PER_TENSOR,
+            compile_native=True,
+        )
 
     @property
     def chunked_prefill_workspace_size(self) -> int:
@@ -545,9 +550,19 @@ class MLAAttention(nn.Module, AttentionLayerBase):
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
-        if output_scale is not None or output_block_scale is not None:
-            raise NotImplementedError(
-                "fused output quantization is not yet supported for MLA"
-            )
+        use_quant = output_scale is not None or output_block_scale is not None
+        if use_quant:
+            # The fusion pass has allocated output with quantized dtype
+            # (FP8 or uint8 for FP4). We can't write into it directly,
+            # so we swap in a temp buffer for computation, then quantize
+            # into the real output at the end.
+            # NOTE(carlyou): this is temporary until kernels support fp8 output
+            quant_output = output
+            output = torch.empty(
+                output.shape[0],
+                self.num_heads * self.v_head_dim,
+                dtype=q.dtype,
+                device=output.device,
+            )
 
         if attn_metadata is None:
@@ -567,6 +582,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                 # The zero fill is required when used with DP + EP
                 # to ensure all ranks within a DP group compute the
                 # same expert outputs.
+                if use_quant:
+                    return quant_output.fill_(0)
                 return output.fill_(0)
 
         if self.impl.dcp_world_size == -1:
@@ -706,6 +723,21 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             # v_up projection
             self._v_up_proj(attn_out, out=mqa_output_slice)
 
+        if use_quant:
+            # Quantize the BF16 computation result into the quantized output
+            actual = output[:num_actual_toks]
+            if output_block_scale is not None:
+                # NVFP4: two FP4 values packed into one uint8
+                fp4_data, fp4_scales = ops.scaled_fp4_quant(actual, output_scale)
+                quant_output[:num_actual_toks].copy_(fp4_data)
+                output_block_scale.copy_(fp4_scales)
+            else:
+                # Static FP8 quantization
+                fp8_data, _ = self._quant_fp8_op(actual, output_scale)
+                quant_output[:num_actual_toks].copy_(fp8_data)
+            return quant_output
+
         return output_padded
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
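The temp-buffer dance above (compute attention in the model dtype, then quantize into the caller's pre-allocated buffer) reduces to scale-and-clamp for static FP8. A simplified pure-Python sketch; the E4M3 max of 448 is the real `float8_e4m3fn` limit, but a real FP8 cast also rounds to the nearest representable value, which this sketch omits:

```python
# Sketch of static per-tensor FP8-style quantization: the attention result
# is computed in high precision, then divided by a precomputed scale and
# clamped into the FP8 representable range. Rounding to actual E4M3 values
# is intentionally omitted here.
FP8_E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn

def quantize_static_fp8(values: list[float], scale: float) -> list[float]:
    out = []
    for v in values:
        q = v / scale  # static, precomputed per-tensor scale
        q = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, q))  # clamp to FP8 range
        out.append(q)
    return out

def dequantize(values: list[float], scale: float) -> list[float]:
    return [v * scale for v in values]

attn_out = [100.0, -300.0, 1000.0]
scale = 2.0
q = quantize_static_fp8(attn_out, scale)
assert q == [50.0, -150.0, 448.0]            # last value clamped
assert dequantize(q, scale)[:2] == attn_out[:2]
```

Values inside the range round-trip exactly under a well-chosen scale; values outside it saturate, which is why the scale is calibrated to the tensor's observed maximum.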
@@ -2069,6 +2101,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     understand this class
     """
 
+    def fused_output_quant_supported(self, quant_key):
+        from vllm.model_executor.layers.quantization.utils.quant_utils import (
+            kFp8StaticTensorSym,
+            kNvfp4Dynamic,
+        )
+
+        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
+
     def __init__(
         self,
         num_heads: int,
@@ -2513,8 +2553,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             if hasattr(self.kv_b_proj, "weight")
             else self.kv_b_proj.params_dtype
         )
-        if use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype():
-            kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype)
+        # For NVFP4, weights are packed uint8 — keep input in model dtype
+        # since the NVFP4 linear layer quantizes internally.
+        if (
+            use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype()
+        ) and _kv_b_proj_w_dtype != torch.uint8:
+            kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
 
         k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
         kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
View File

@@ -10,6 +10,11 @@ import numpy as np
 import torch
 from typing_extensions import deprecated
 
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    kFp8StaticTensorSym,
+    kNvfp4Dynamic,
+)
+
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
     from vllm.config.cache import CacheDType
@@ -873,6 +878,14 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
         """MQA-style decode forward pass."""
         raise NotImplementedError
 
+    def fused_output_quant_supported(self, quant_key: "QuantKey"):
+        """
+        Does this attention implementation support fused output quantization.
+        Since MLA quantization is done manually in forward_impl (common code),
+        all MLA backends support it by default.
+        """
+        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
+
     def do_kv_cache_update(
         self,
         kv_c_normed: torch.Tensor,
@@ -903,6 +916,14 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
     They do not support prefill (MHA-style) attention.
     """
 
+    def fused_output_quant_supported(self, quant_key: "QuantKey"):
+        """
+        Does this attention implementation support fused output quantization.
+        Since MLA quantization is done manually in forward_impl (common code),
+        all MLA backends support it by default.
+        """
+        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
+
    @abstractmethod
    def __init__(
        self,
self,