Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com> Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
@@ -72,6 +72,7 @@ steps:
       - vllm/v1/attention/backends/flashinfer.py
       - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
       - tests/compile/passes/test_fusion_attn.py
+      - tests/compile/passes/test_mla_attn_quant_fusion.py
       - tests/compile/passes/test_silu_mul_quant_fusion.py
       - tests/compile/passes/distributed/test_fusion_all_reduce.py
       - tests/compile/fullgraph/test_full_graph.py
@@ -79,6 +80,7 @@ steps:
       # b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
       - nvidia-smi
       - pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
+      - pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
       - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
       # this runner has 2 GPUs available even though num_devices=2 is not set
       - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
@@ -22,6 +22,7 @@ or just on the low or high end.
 | ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
 | [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
 | [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
+| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always |
 | [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
 | [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low |
 | [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High |
@@ -40,6 +41,7 @@ The table below lists the quantization schemes supported by each fusion on each
 | ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
 | `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
 | `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
+| `fuse_attn_quant` (MLA)\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static (untested)\* |
 | `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
 | `enable_qk_norm_rope_fusion` | FP16/BF16 | FP16/BF16 | FP16/BF16† | FP16/BF16† | — |
 | `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
@@ -129,7 +131,8 @@ on SM90/SM100) and configurable via `PassConfig.fi_allreduce_fusion_max_size_mb`
 explicitly. It requires the full model graph to be visible (Inductor partition or `splitting_ops=[]`).

 **What it fuses.** Fuses the attention output quantization directly after the attention computation,
-eliminating a full-precision memory round-trip of the attention output. Patterns covered:
+eliminating a full-precision memory round-trip of the attention output. This fusion supports both
+standard `Attention` and `MLAAttention` (used by DeepSeek-V2/V3/R1 models). Patterns covered:

 `Attention → FP8 static quant`:

@@ -142,11 +145,24 @@ eliminating a full-precision memory round-trip of the attention output. Patterns

 - `FLASHINFER`: CUDA sm100+ with FlashInfer installed

+`MLAAttention → FP8 static quant` / `MLAAttention → NVFP4 dynamic quant`:
+
+The MLA fusion operates at the graph level on the `unified_mla_attention_with_output` op and works
+with all MLA decode and prefill backend combinations. Unlike standard `Attention` backends (where
+the kernel writes FP8 output directly), no MLA prefill or decode backend currently supports direct
+FP8/FP4 output. The fusion writes to an intermediate buffer and quantizes in a separate step, so
+there is no memory round-trip elimination yet.
+
+!!! info
+    The MLA attention fusion is not expected to yield a measurable speedup yet.
+    This will improve once MLA prefill/decode kernels support direct FP8/FP4 output.
+
 Other attention backends do not support fused output quantization yet.

 **Code locations.**

-- Pass: [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
+- Pass (Attention): [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
+- Pass (MLAAttention): [`vllm/compilation/passes/fusion/mla_attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/mla_attn_quant_fusion.py)
 - Attention backends: [`vllm/v1/attention/backends/`](https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/)

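The "quantize in a separate step" behavior described above can be sketched in plain Python. This is a toy model only: lists stand in for tensors, the scale and inputs are made up, and the only facts assumed are that static quantization divides by a precomputed scale and that FP8 E4M3 saturates at ±448.

```python
# Toy model of "write to an intermediate buffer, then quantize in a
# separate step" (the current MLA behavior). The real pass rewrites
# torch.fx graphs; nothing here is vLLM API.

FP8_E4M3_MAX = 448.0  # saturation bound of the FP8 E4M3 format

def static_quant(values, scale):
    # Quantize with a precomputed (static) scale, saturating to the E4M3 range.
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, round(v / scale))) for v in values]

def mla_attn_then_quant(attn_output, scale):
    # Step 1: the attention kernel writes a full-precision intermediate buffer.
    intermediate = list(attn_output)
    # Step 2: a separate quantization pass reads the whole buffer back --
    # this extra read/write is the memory round-trip not yet eliminated.
    return static_quant(intermediate, scale)

quantized = mla_attn_then_quant([0.5, -1.0, 1000.0], scale=0.25)
```

A backend with direct FP8 output would instead perform step 2 inside the kernel, writing quantized values once.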
### RoPE + KV-Cache Update (`fuse_rope_kvcache`)
@@ -84,10 +84,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
     rocm_aiter_ops.refresh_env_variables()

     # Filter here to reduce code duplication
+    backend_name = attn_backend.backend.name.lower()
     requires_mla = "deepseek" in model_name.lower()
-    is_mla = "mla" in attn_backend.backend.name.lower()
+    is_mla = "mla" in backend_name
+    # DeepSeek V3.2 uses sparse MLA
+    requires_sparse = "v3.2" in model_name.lower()
+    is_sparse = "sparse" in backend_name

-    if requires_mla != is_mla:
+    if requires_mla != is_mla or requires_sparse != is_sparse:
         pytest.skip(
             f"Incompatible model '{model_name}' and "
             f"attention backend '{attn_backend.backend.name}'"
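The widened skip condition above can be mirrored as a standalone predicate (a sketch with plain strings in place of the test's enum-backed backend objects):

```python
def is_compatible(model_name: str, backend_name: str) -> bool:
    """Mirror of the test's filter: DeepSeek models need an MLA backend,
    and DeepSeek V3.2 additionally needs a *sparse* MLA backend."""
    model = model_name.lower()
    backend = backend_name.lower()
    requires_mla = "deepseek" in model
    is_mla = "mla" in backend
    requires_sparse = "v3.2" in model
    is_sparse = "sparse" in backend
    return requires_mla == is_mla and requires_sparse == is_sparse

ok = is_compatible("deepseek-ai/DeepSeek-V3", "TRITON_MLA")
bad = is_compatible("nvidia/DeepSeek-V3.2-NVFP4", "TRITON_MLA")
```

Without the sparse check, V3.2 would be paired with dense MLA backends and fail for reasons unrelated to fusion, which is why the filter skips those combinations up front.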
@@ -231,7 +235,9 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
         )

     elif match_name == "attn_quant_fusion":
-        actual_match = match_table.get(match_name, 0)
+        actual_match = match_table.get(
+            "attn_quant_fusion", 0
+        ) + match_table.get("mla_attn_quant_fusion", 0)
         assert actual_match == expected_matches * n_expected, (
             f"Could not find {expected_matches * n_expected} "
             f"{match_name} (found {actual_match})."
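Summing the two counters can be illustrated with a plain dict (toy numbers, not real match counts):

```python
# After the MLA pass was split out, attention+quant matches are reported
# under two keys; the assertion sums both so either pass (or a mix of
# standard and MLA layers) can satisfy the expected count.
match_table = {"attn_quant_fusion": 3, "mla_attn_quant_fusion": 4}

actual_match = match_table.get(
    "attn_quant_fusion", 0
) + match_table.get("mla_attn_quant_fusion", 0)
```

Using `dict.get(key, 0)` keeps the sum well-defined when one of the passes produced no matches at all.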
@@ -58,6 +58,15 @@ TRITON_MLA_ATTN = pytest.param(
     id="TRITON_MLA",
 )

+FLASHMLA_SPARSE_ATTN = pytest.param(
+    AttentionBackendCase(backend=AttentionBackendEnum.FLASHMLA_SPARSE),
+    id="FLASHMLA_SPARSE",
+    marks=pytest.mark.skipif(
+        not is_blackwell(),
+        reason="FlashMLA Sparse requires Blackwell",
+    ),
+)
+
 # Models
 llama3_8b = ModelFusionInfo(
     model_name="meta-llama/Llama-3.1-8B-Instruct",
@@ -141,6 +150,18 @@ qwen3_a3b_fp8 = ModelFusionInfo(
     ),
 )

+deepseek_coder_v2_lite_fp8 = ModelFusionInfo(
+    model_name="RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8",
+    matches=lambda n_layers: Matches(
+        # first_k_dense_replace=1; MoE hides most rms+quant sites
+        rms_quant_fusion=1,
+        act_quant_fusion=min(1, n_layers),  # dense layers only
+        # MLA attn + static FP8 quant
+        attn_quant_fusion=n_layers,
+        ar_rms_fusion=n_layers * 2 + 1,
+    ),
+)
+
 deepseek_v3_fp8 = ModelFusionInfo(
     model_name="deepseek-ai/DeepSeek-V3",
     matches=lambda n_layers: Matches(
@@ -152,7 +173,7 @@ deepseek_v3_fp8 = ModelFusionInfo(
         rms_quant_fusion=n_layers * 2 + min(3, n_layers),  # add for 3 dense layers
         # silu+block quant
         act_quant_fusion=min(3, n_layers),  # dense layers only
-        # MLA attn + quant not supported yet:
+        # MLA attn + per-group FP8 quant not supported yet:
         # https://github.com/vllm-project/vllm/issues/35792
         attn_quant_fusion=0,
         ar_rms_fusion=n_layers * 2 + 1,
@@ -162,6 +183,16 @@ deepseek_v3_fp8 = ModelFusionInfo(
     ),
 )

+deepseek_v32_fp4 = ModelFusionInfo(
+    model_name="nvidia/DeepSeek-V3.2-NVFP4",
+    matches=lambda n_layers: Matches(
+        rms_quant_fusion=0,
+        act_quant_fusion=0,
+        attn_quant_fusion=n_layers,
+        ar_rms_fusion=n_layers * 2 + 1,
+    ),
+)
+
 gpt_oss_20b = ModelFusionInfo(
     model_name="openai/gpt-oss-20b",
     matches=lambda n_layers: Matches(
@@ -18,11 +18,14 @@ from .common import (
 from .models import (
     FLASHINFER_ATTN,
     FLASHINFER_MLA_ATTN,
+    FLASHMLA_SPARSE_ATTN,
     ROCM_AITER_UNIFIED_ATTN,
     ROCM_ATTN,
     TRITON_ATTN,
     TRITON_MLA_ATTN,
+    deepseek_coder_v2_lite_fp8,
     deepseek_v3_fp8,
+    deepseek_v32_fp4,
     llama3_8b_fp4,
     llama3_8b_fp8,
     llama4_scout_fp4,
@@ -37,6 +40,7 @@ from .models import (
     (*llama3_8b_fp8, False),
     (*qwen3_a3b_fp8, False),
     (*qwen3_a3b_fp8, True),
+    (*deepseek_coder_v2_lite_fp8, False),
     (*deepseek_v3_fp8, False),
     (*deepseek_v3_fp8, True),
     pytest.param(
@@ -144,9 +148,12 @@ def test_tp1_fp8_fusions(

 @pytest.mark.parametrize(
     "model_name, matches_fn, model_kwargs, hf_overrides",
-    [llama3_8b_fp4, llama4_scout_fp4],
+    [llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
 )
-@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
+)
 @pytest.mark.parametrize("n_layers", [6])
 @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
 @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@@ -18,8 +18,11 @@ from .common import (
 from .models import (
     FLASHINFER_ATTN,
     FLASHINFER_MLA_ATTN,
+    FLASHMLA_SPARSE_ATTN,
     TRITON_ATTN,
+    deepseek_coder_v2_lite_fp8,
     deepseek_v3_fp8,
+    deepseek_v32_fp4,
     gpt_oss_20b,
     llama3_8b,
     llama3_8b_fp4,
@@ -37,7 +40,13 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes
 @pytest.mark.parametrize(
     "model_name, matches_fn, model_kwargs, hf_overrides",
     # qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
-    [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
+    [
+        llama3_8b_fp8,
+        llama4_scout_fp8,
+        qwen3_a3b_fp8,
+        deepseek_coder_v2_lite_fp8,
+        deepseek_v3_fp8,
+    ],
 )
 @pytest.mark.parametrize(
     "attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
@@ -104,9 +113,12 @@ def test_tp2_ar_rms_fp8_fusions(
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
     "model_name, matches_fn, model_kwargs, hf_overrides",
-    [llama3_8b_fp4, llama4_scout_fp4],
+    [llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
 )
-@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
+)
 @pytest.mark.parametrize("n_layers", [4])
 @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
 @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
508  tests/compile/passes/test_mla_attn_quant_fusion.py  Normal file
@@ -0,0 +1,508 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy

import pytest
import torch._dynamo

from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import TestFP8Layer, flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.mla_attn_quant_fusion import (
    MLA_ATTN_OP,
    MLAAttnQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import (
    AttentionConfig,
    CacheConfig,
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    SchedulerConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8StaticTensorSym,
    kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import MLAAttentionSpec

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8


class MLAAttentionQuantPatternModel(torch.nn.Module):
    """Base model for MLA AttentionQuantPattern fusion."""

    def __init__(
        self,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        kv_lora_rank: int,
        kv_cache_dtype: torch.dtype,
        device: torch.device,
        vllm_config: VllmConfig,
        **kwargs,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank
        self.output_dim = num_heads * v_head_dim
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        # Create kv_b_proj (ColumnParallelLinear) on device.
        # Reuse weights from prior model instance when available, because
        # ColumnParallelLinear may get NaN from recycled CUDA memory after
        # torch.compile runs in the same process.
        kv_b_proj = ColumnParallelLinear(
            input_size=kv_lora_rank,
            output_size=num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False,
            prefix="model.layers.0.self_attn.kv_b_proj",
        ).to(device)
        kv_b_proj_weight = kwargs.get("kv_b_proj_weight")
        if kv_b_proj_weight is not None:
            kv_b_proj.weight.data.copy_(kv_b_proj_weight)
        elif kv_b_proj.weight.data.isnan().any():
            # Sanitize NaN from recycled CUDA memory
            kv_b_proj.weight.data.normal_()

        # Create MLAAttention
        self.mla_attn = MLAAttention(
            num_heads=num_heads,
            scale=1.0 / (self.qk_head_dim**0.5),
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            q_lora_rank=None,
            kv_lora_rank=kv_lora_rank,
            kv_b_proj=kv_b_proj,
            cache_config=vllm_config.cache_config,
            quant_config=self.quant_config,
            prefix="model.layers.0.self_attn.attn",
        )
        self.mla_attn._k_scale = self.mla_attn._k_scale.to(device)
        self.mla_attn._v_scale = self.mla_attn._v_scale.to(device)

        # Initialize W_UK_T and W_UV from kv_b_proj weights
        self.mla_attn.process_weights_after_loading(torch.get_default_dtype())
        self.kv_b_proj_weight = kv_b_proj.weight.data.clone()

        self.block_size = 16

        # Initialize MLA MetadataBuilder
        self.builder = self.mla_attn.attn_backend.get_builder_cls()(
            kv_cache_spec=MLAAttentionSpec(
                block_size=self.block_size,
                num_kv_heads=1,
                head_size=self.head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.mla_attn.layer_name],
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
        """Initialize MLA attention metadata.

        NOTE: Uses decode-only batch (query_len=1 per request). The prefill
        (forward_mha) path is not separately tested here because it requires
        FlashAttention availability and different input tensor shapes. The
        quant logic in forward_impl is identical for both paths — it quantizes
        the full output[:num_actual_toks] buffer after both forward_mha and
        forward_mqa have written their results.
        """

        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks

        # MLA KV cache is 3D: (num_blocks, block_size, head_size)
        attn_backend = self.mla_attn.attn_backend
        kv_cache_shape = attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, 1, self.head_size
        )
        try:
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
        except (AttributeError, NotImplementedError):
            kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

        ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
        inv_order = [
            kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
        ]

        raw_tensor = torch.zeros(
            ordered_shape, dtype=self.kv_cache_dtype, device=self.device
        )
        kv_cache = raw_tensor.permute(*inv_order)

        self.mla_attn.kv_cache = kv_cache

        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

        return self.attn_metadata


class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
    """Test model for MLA Attention + FP8 static quant fusion."""

    quant_key = kFp8StaticTensorSym
    quant_config = Fp8Config()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.fp8_linear = TestFP8Layer(
            weight_shape=(self.output_dim, self.output_dim),
            activation_quant_key=self.quant_key,
            weight_quant_key=self.quant_key,
            device=self.device,
        )

        w = kwargs.get("w")
        if w is not None:
            self.fp8_linear.weight = w["weight"]
            self.fp8_linear.weight_scale = w["wscale"]
            self.fp8_linear.input_scale = w["scale"]

        self.w = {
            "weight": self.fp8_linear.weight,
            "wscale": self.fp8_linear.weight_scale,
            "scale": self.fp8_linear.input_scale,
        }

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
    ):
        """Forward pass that creates the MLA attention + FP8 quant pattern."""
        attn_output = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(q.shape[0], self.output_dim),
        )
        return self.fp8_linear(attn_output)


class TestMLAAttentionNvfp4QuantPatternModel(MLAAttentionQuantPatternModel):
    """Test model for MLA Attention + NVFP4 quant fusion."""

    quant_key = kNvfp4Dynamic
    quant_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=False,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.w = kwargs.get(
            "w",
            {
                "weight": torch.randint(
                    256,
                    (self.output_dim, self.output_dim // 2),
                    dtype=FP4_DTYPE,
                    device=self.device,
                ),
                "wscale_swizzled": torch.randn(
                    self.output_dim, self.output_dim // 16
                ).to(dtype=FP8_DTYPE, device=self.device),
                "wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
            },
        )

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
    ):
        """Forward pass that creates the MLA attention + NVFP4 quant pattern."""
        attn_output = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(q.shape[0], self.output_dim),
        )
        quant_output, output_block_scale = scaled_fp4_quant(
            attn_output, 1 / self.w["scale"]
        )
        return cutlass_scaled_fp4_mm(
            a=quant_output,
            b=self.w["weight"],
            block_scale_a=output_block_scale,
            block_scale_b=self.w["wscale_swizzled"],
            alpha=self.w["scale"] * self.w["wscale"],
            out_dtype=attn_output.dtype,
        )


def is_nvfp4_supported():
    return current_platform.has_device_capability(100)


# MLA test configuration
MLA_DIMS: list[tuple[int, int, int, int, int]] = []
PATTERN_TEST_MODELS_MLA_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_MLA_FP4: list[tuple[str, type]] = []
BACKENDS_MLA_FP8: list[AttentionBackendEnum] = []
BACKENDS_MLA_FP4: list[AttentionBackendEnum] = []

if current_platform.is_cuda():
    # (num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank)
    MLA_DIMS = [(16, 128, 64, 128, 512)]
    PATTERN_TEST_MODELS_MLA_FP8 = [
        (
            "deepseek-ai/DeepSeek-V2-Lite",
            TestMLAAttentionFp8StaticQuantPatternModel,
        )
    ]
    PATTERN_TEST_MODELS_MLA_FP4 = [
        (
            "deepseek-ai/DeepSeek-V2-Lite",
            TestMLAAttentionNvfp4QuantPatternModel,
        )
    ]
    BACKENDS_MLA_FP8 = [AttentionBackendEnum.TRITON_MLA]
    BACKENDS_MLA_FP4 = [AttentionBackendEnum.TRITON_MLA]


@pytest.mark.parametrize(
    "num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank",
    MLA_DIMS,
)
@pytest.mark.parametrize("batch_size", [7, 256] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
    "backend, model_name, model_class, custom_ops",
    list(
        flat_product(
            BACKENDS_MLA_FP8,
            PATTERN_TEST_MODELS_MLA_FP8,
            ["+quant_fp8", "-quant_fp8"],
        )
    )
    + list(flat_product(BACKENDS_MLA_FP4, PATTERN_TEST_MODELS_MLA_FP4, [""])),
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
def test_mla_attention_quant_pattern(
    num_heads: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    kv_lora_rank: int,
    batch_size: int,
    dtype: torch.dtype,
    custom_ops: str,
    model_name: str,
    model_class: type[MLAAttentionQuantPatternModel],
    backend: AttentionBackendEnum,
    dist_init,
    monkeypatch,
    use_fresh_inductor_cache,
):
    """Test MLA AttentionQuantPattern fusion pass"""
    if (
        model_class is TestMLAAttentionNvfp4QuantPatternModel
        and not is_nvfp4_supported()
    ):
        pytest.skip("NVFP4 is not supported on this GPU (requires SM 100+).")

    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.manual_seed(42)

    model_config = ModelConfig(
        model=model_name,
        max_model_len=2048,
        dtype=dtype,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        scheduler_config=SchedulerConfig(
            max_num_seqs=1024,
            max_model_len=model_config.max_model_len,
            is_encoder_decoder=model_config.is_encoder_decoder,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops_list,
        ),
        cache_config=CacheConfig(cache_dtype="auto"),
        attention_config=AttentionConfig(backend=backend),
    )

    # MLA inputs: q(B, N, qk_head_dim), kv_c_normed(B, L), k_pe(B, 1, R)
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    q = torch.randn(batch_size, num_heads, qk_head_dim, dtype=dtype, device=device)
    kv_c_normed = torch.randn(batch_size, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(batch_size, 1, qk_rope_head_dim, dtype=dtype, device=device)

    # Mark first dimension as dynamic
    torch._dynamo.mark_dynamic(q, 0)
    torch._dynamo.mark_dynamic(kv_c_normed, 0)
    torch._dynamo.mark_dynamic(k_pe, 0)

    # Run model without fusion
    vllm_config_unfused = copy.deepcopy(vllm_config)
    with (
        set_current_vllm_config(vllm_config_unfused),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
    ):
        model_unfused = model_class(
            num_heads=num_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            kv_lora_rank=kv_lora_rank,
            kv_cache_dtype=dtype,
            device=device,
            vllm_config=vllm_config_unfused,
        )
        model_unfused = model_unfused.to(device)
        # HACK: See https://github.com/vllm-project/vllm/issues/31044
        result_unfused_0 = model_unfused(q, kv_c_normed, k_pe)  # noqa: F841

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)

        compiled_unfused = torch.compile(model_unfused, fullgraph=True)
        result_unfused = compiled_unfused(q, kv_c_normed, k_pe)

    # Run model with attn fusion enabled
    vllm_config.compilation_config.pass_config = PassConfig(
        fuse_attn_quant=True, eliminate_noops=True
    )
    with (
        set_current_vllm_config(vllm_config),
        set_forward_context(attn_metadata=None, vllm_config=vllm_config),
    ):
        model_fused = model_class(
            num_heads=num_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            kv_lora_rank=kv_lora_rank,
            kv_cache_dtype=dtype,
            device=device,
            vllm_config=vllm_config,
            w=model_unfused.w,
            kv_b_proj_weight=model_unfused.kv_b_proj_weight,
        )
        model_fused = model_fused.to(device)

        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)

        # Create test backend with fusion passes
        noop_pass = NoOpEliminationPass(vllm_config)
        attn_pass = LazyInitPass(MLAAttnQuantFusionPass, vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
        # HACK: See https://github.com/vllm-project/vllm/issues/31044
        result_fused_0 = model_fused(q, kv_c_normed, k_pe)  # noqa: F841

        compiled_fused = torch.compile(
            model_fused, backend=test_backend, fullgraph=True
        )

        result_fused = compiled_fused(q, kv_c_normed, k_pe)

    # Check attn fusion support
    quant_key: QuantKey = model_class.quant_key
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for key, layer in vllm_config.compilation_config.static_forward_context.items()
        if isinstance(layer, MLAAttention)
    ]
    assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
        "All MLA layers should support attention fusion"
    )

    # Check quantization ops in the graph
    quant_op = (
        torch.ops.aten.reciprocal
        if "-quant_fp8" in custom_ops_list
        else QUANT_OPS[quant_key]
    )
    test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)

    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

    # Check MLA attention ops in the graph
    attn_nodes_pre = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_post_pass))

    assert len(attn_nodes_pre) > 0, "Should have MLA attention nodes before fusion"
    assert len(attn_nodes_pre) == len(attn_nodes_post), (
        "Should have same number of MLA attention nodes before and after fusion"
    )
    assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
        "MLA attention should not have output_scale before fusion"
    )
    assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
        "MLA attention should have output_scale after fusion"
    )

    assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
        "MLA attention should not have output_block_scale before fusion"
    )

    if quant_key.dtype == FP8_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
            "MLA attention should not have output_block_scale after FP8 fusion"
        )
    elif quant_key.dtype == FP4_DTYPE:
        assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
            "MLA attention should have output_block_scale after FP4 fusion"
        )

    # Check numerical correctness
    torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)
262  vllm/compilation/passes/fusion/mla_attn_quant_fusion.py  Normal file
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

from vllm._custom_ops import create_fp4_output_tensors
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLAAttention
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
    kNvfp4Dynamic,
)
from vllm.platforms import current_platform

from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from .matcher_utils import MatcherQuantFP8
from .rms_quant_fusion import QUANT_OPS

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8

MLA_ATTN_OP = torch.ops.vllm.unified_mla_attention_with_output.default


class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
    """
    Fusion for MLA Attention + Fp8StaticQuant.

    Matches the pattern: MLA attention -> static FP8 quant, and replaces
    it with MLA attention(output_scale=scale, output=fp8_buffer).
    """

    def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None:
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._v_head_dim = layer.v_head_dim
        self._kv_lora_rank = layer.kv_lora_rank
        self._qk_rope_head_dim = layer.qk_rope_head_dim
        self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim
        self._output_dim = layer.num_heads * layer.v_head_dim
        self._dtype = dtype
        self._quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    @property
    def pattern(self) -> Callable[..., torch.Tensor]:
        def _pattern(
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            at1 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # MLA output is already 2D (T, N*V), no reshape needed
            return self._quant_matcher(at1[1], scale)[0]

        return _pattern

    @property
    def replacement(self) -> Callable[..., torch.Tensor]:
        def _replacement(
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # MLA output in quant_dtype
            output_attn = torch.empty(
                [q.shape[0], self._output_dim],
                dtype=FP8_DTYPE,
                device=q.device,
            )
            at1 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return at1[1]

        return _replacement

    def get_inputs(self) -> list[torch.Tensor]:
        return [
            self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
            self.empty(5, self._kv_lora_rank, dtype=self._dtype),
            self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
            self.empty(5, self._output_dim, dtype=self._dtype),
            self.empty_fp32(1, 1),
            self.empty(0, dtype=self._dtype),
        ]


class MLAAttnNvfp4QuantPattern(
    VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]]
):
    """
    Fusion for MLA Attention + Nvfp4Quant.

    Matches the pattern: MLA attention -> NVFP4 quant, and replaces
    it with MLA attention(output_scale=scale, output_block_scale=block_scale,
    output=fp4_buffer).
    """

    def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None:
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._v_head_dim = layer.v_head_dim
        self._kv_lora_rank = layer.kv_lora_rank
        self._qk_rope_head_dim = layer.qk_rope_head_dim
        self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim
        self._output_dim = layer.num_heads * layer.v_head_dim
        self._dtype = dtype
        self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic]

    @property
    def pattern(
        self,
    ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        def _pattern(
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            output_attn: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at1 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # Replicate what scaled_fp4_quant() does: allocate output
            # tensors inline then call the .out variant.
            output_quant, output_scale = create_fp4_output_tensors(
                at1[1].shape[0], at1[1].shape[1], at1[1].device, True
            )
            at2 = auto_functionalized(
                self._QUANT_OP,
                input=at1[1],
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
                output=output_quant,
                output_scale=output_scale,
            )
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        return _pattern

    @property
    def replacement(
        self,
    ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        def _replacement(
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            output_attn: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # MLA output in quant_dtype (FP4 packed as uint8)
            output_attn = torch.empty(
                [q.shape[0], self._output_dim // 2],
                dtype=FP4_DTYPE,
                device=q.device,
            )
            # attention output block scale
            output_scale = create_fp4_output_tensors(
                q.shape[0], self._output_dim, q.device, True
            )[1]
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            at2 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=self._layer_name,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return at2[1], at2[2]

        return _replacement

    def get_inputs(self) -> list[torch.Tensor]:
        return [
            self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
            self.empty(5, self._kv_lora_rank, dtype=self._dtype),
            self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
            self.empty(5, self._output_dim, dtype=self._dtype),
            self.empty_fp32(1, 1),
            self.empty(0, dtype=self._dtype),
        ]


class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto MLA attention if supported.

    It uses the pattern matcher and matches each MLA layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config, "mla_attn_quant_fusion")

        dtype = config.model_config.dtype
        layers = list(get_layers_from_vllm_config(config, MLAAttention).values())

        if len(layers) == 0:
            logger.warning(
                "MLA attention + quant fusion is enabled, but no MLA "
                "attention layers were found in "
                "CompilationConfig.static_forward_context, "
                "so no fusion patterns were registered."
            )

        for layer in layers:
            if layer.impl.fused_output_quant_supported(kFp8StaticTensorSym):
                self.register(MLAAttnFp8StaticQuantPattern(layer, dtype))

        if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
            for layer in layers:
                if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
                    self.register(MLAAttnNvfp4QuantPattern(layer, dtype))

        self.dump_patterns(config, self.pm_pass)
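The FP8 pattern above fuses a static per-tensor quant into the attention op. The arithmetic being fused can be sketched in plain Python (a minimal illustration, not vLLM API; the ±448 bound is the finite range of the float8_e4m3fn format used on CUDA):

```python
def quantize_per_tensor(values, scale, qmin=-448.0, qmax=448.0):
    # Static per-tensor quantization: divide by a precomputed scale,
    # then clamp to the representable range of the target format.
    return [min(max(v / scale, qmin), qmax) for v in values]

result = quantize_per_tensor([100.0, 1000.0], 2.0)  # [50.0, 448.0]
```

Because the scale is static (known at registration time), the fusion pass can hand it to the attention kernel as `output_scale` instead of running a separate quant op on the attention output.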
@@ -27,6 +27,7 @@ if rocm_aiter_ops.is_enabled():
if current_platform.is_cuda_alike():
    from .fusion.act_quant_fusion import ActivationQuantFusionPass
    from .fusion.attn_quant_fusion import AttnQuantFusionPass
    from .fusion.mla_attn_quant_fusion import MLAAttnQuantFusionPass
    from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
    from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
    from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
@@ -157,6 +158,7 @@ class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]

        if self.pass_config.fuse_attn_quant:
            self.passes += [AttnQuantFusionPass(config)]
            self.passes += [MLAAttnQuantFusionPass(config)]

        if self.pass_config.enable_qk_norm_rope_fusion:
            self.passes += [SplitCoalescingPass(config)]
@@ -121,7 +121,7 @@ class PassConfig:
    fuse_act_quant: bool = None  # type: ignore[assignment]
    """Fuse the custom SiluMul + quant ops."""
    fuse_attn_quant: bool = None  # type: ignore[assignment]
    """Fuse the custom Attention and MLAAttention + quant ops."""
    eliminate_noops: bool = Field(default=True)
    """Eliminate no-op ops."""
    enable_sp: bool = None  # type: ignore[assignment]
@@ -449,6 +449,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
compile_native=True,
|
||||
)
|
||||
self._quant_fp8_op = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
compile_native=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def chunked_prefill_workspace_size(self) -> int:
|
||||
@@ -545,9 +550,19 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for MLA"
|
||||
use_quant = output_scale is not None or output_block_scale is not None
|
||||
if use_quant:
|
||||
# The fusion pass has allocated output with quantized dtype
|
||||
# (FP8 or uint8 for FP4). We can't write into it directly,
|
||||
# so we swap in a temp buffer for computation, then quantize
|
||||
# into the real output at the end.
|
||||
# NOTE(carlyou): this is temporary until kernels support fp8 output
|
||||
quant_output = output
|
||||
output = torch.empty(
|
||||
output.shape[0],
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=q.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
@@ -567,6 +582,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
if use_quant:
|
||||
return quant_output.fill_(0)
|
||||
return output.fill_(0)
|
||||
|
||||
if self.impl.dcp_world_size == -1:
|
||||
@@ -706,6 +723,21 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
|
||||
# v_up projection
|
||||
self._v_up_proj(attn_out, out=mqa_output_slice)
|
||||
|
||||
if use_quant:
|
||||
# Quantize the BF16 computation result into the quantized output
|
||||
actual = output[:num_actual_toks]
|
||||
if output_block_scale is not None:
|
||||
# NVFP4: two FP4 values packed into one uint8
|
||||
fp4_data, fp4_scales = ops.scaled_fp4_quant(actual, output_scale)
|
||||
quant_output[:num_actual_toks].copy_(fp4_data)
|
||||
output_block_scale.copy_(fp4_scales)
|
||||
else:
|
||||
# Static FP8 quantization
|
||||
fp8_data, _ = self._quant_fp8_op(actual, output_scale)
|
||||
quant_output[:num_actual_toks].copy_(fp8_data)
|
||||
return quant_output
|
||||
|
||||
return output_padded
|
||||
|

    def process_weights_after_loading(self, act_dtype: torch.dtype):
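The "two FP4 values packed into one uint8" comment above refers to nibble packing, which is why the replacement pattern allocates the FP4 output with `self._output_dim // 2` columns. A hedged sketch of the layout (the nibble order shown is an assumption for illustration; the real packing is done inside `ops.scaled_fp4_quant`):

```python
def pack_fp4_pairs(codes):
    # Pack consecutive 4-bit codes into bytes, two per uint8:
    # the even-index element goes in the low nibble (assumed order).
    assert len(codes) % 2 == 0, "FP4 packing needs an even element count"
    return [(codes[i + 1] << 4) | codes[i] for i in range(0, len(codes), 2)]
```

This halving of the element count is also why the fused output cannot be written in place: the attention kernel produces one BF16 value per element, so the pass computes into a temp buffer and packs at the end.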
@@ -2069,6 +2101,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
    understand this class
    """

    def fused_output_quant_supported(self, quant_key):
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            kFp8StaticTensorSym,
            kNvfp4Dynamic,
        )

        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)

    def __init__(
        self,
        num_heads: int,
@@ -2513,8 +2553,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
            if hasattr(self.kv_b_proj, "weight")
            else self.kv_b_proj.params_dtype
        )
        # For NVFP4, weights are packed uint8; keep input in model dtype
        # since the NVFP4 linear layer quantizes internally.
        if (
            use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype()
        ) and _kv_b_proj_w_dtype != torch.uint8:
            kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)

        k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
@@ -10,6 +10,11 @@ import numpy as np
import torch
from typing_extensions import deprecated

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
    kNvfp4Dynamic,
)

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
@@ -873,6 +878,14 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
        """MQA-style decode forward pass."""
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: "QuantKey"):
        """
        Does this attention implementation support fused output quantization?
        Since MLA quantization is done manually in forward_impl (common code),
        all MLA backends support it by default.
        """
        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
@@ -903,6 +916,14 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
    They do not support prefill (MHA-style) attention.
    """

    def fused_output_quant_supported(self, quant_key: "QuantKey"):
        """
        Does this attention implementation support fused output quantization?
        Since MLA quantization is done manually in forward_impl (common code),
        all MLA backends support it by default.
        """
        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)

    @abstractmethod
    def __init__(
        self,