[Bugfix] Fix QK Norm+RoPE fusion pattern matching on B200+FP8 (#33967)

Signed-off-by: Ikenna <ikennachifo@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Ikenna
Date: 2026-02-06 21:27:33 -05:00 (committed via GitHub)
Commit: 906077181b (parent: 89a385d79f)
5 changed files with 154 additions and 6 deletions
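Summary (distilled from the diffs below): on B200 with FP8, the post-grad graph can contain several identical aten.split_with_sizes calls on the same QKV tensor that Inductor's CSE fails to merge (vLLM #33295, PyTorch #174472). The QK Norm+RoPE fusion pattern expects a single split node with all of q/k/v attached as users, so matching failed on that hardware/dtype combination. This commit adds a SplitCoalescingPass that merges the duplicates and schedules it ahead of QKNormRoPEFusionPass. A minimal sketch (variable names illustrative) of the scattered-split shape that broke matching:

    # Each split feeds only one of q/k/v, so no single split_with_sizes
    # node carries all three getitem users the fusion pattern looks for.
    q, _, _ = qkv.split([q_size, kv_size, kv_size], dim=-1)
    _, k, _ = qkv.split([q_size, kv_size, kv_size], dim=-1)
    _, _, v = qkv.split([q_size, kv_size, kv_size], dim=-1)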

File 1/5: expected fusion counts for the end-to-end model tests

@@ -101,9 +101,7 @@ qwen3_a3b_fp8 = ModelFusionInfo(
     model_name="Qwen/Qwen3-30B-A3B-FP8",
     matches=lambda n_layers: Matches(
         rms_quant_fusion=n_layers,
-        # TODO broken on Blackwell:
-        # https://github.com/vllm-project/vllm/issues/33295
-        norm_rope_fusion=0 if is_blackwell() else n_layers,
+        norm_rope_fusion=n_layers,
         attn_quant_fusion=0,  # attn + group quant not supported
         ar_rms_fusion=n_layers * 2 + 1,
         sequence_parallel=n_layers * 2 + 1,

File 2/5: the QK Norm+RoPE fusion unit test

@@ -16,6 +16,7 @@ from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
 )
 from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
 from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
+from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
 from vllm.config import (
     CompilationConfig,
     CompilationMode,
@@ -45,6 +46,7 @@ class QKNormRoPETestModel(torch.nn.Module):
         is_neox: bool,
         vllm_config: VllmConfig,
         dtype: torch.dtype,
+        test_scattered_split: bool = False,
         prefix: str = "model.layers.0.self_attn.attn",
     ) -> None:
         super().__init__()
@@ -78,10 +80,16 @@ class QKNormRoPETestModel(torch.nn.Module):
             is_neox_style=is_neox,
             dtype=self.dtype,
         )
+        self.test_scattered_split = test_scattered_split
         self.enable_rms_norm_custom_op = self.q_norm.enabled()
         self.enable_rope_custom_op = self.rotary_emb.enabled()

     def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.test_scattered_split:
+            q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            _, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            _, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        else:
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
         q_by_head = self.q_norm(q_by_head)
@@ -112,6 +120,7 @@ class QKNormRoPETestModel(torch.nn.Module):
         return [FUSED_QK_ROPE_OP]


+@pytest.mark.parametrize("scattered_split", [True, False])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("is_neox", [True, False])
 @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@@ -122,7 +131,12 @@ class QKNormRoPETestModel(torch.nn.Module):
     reason="Only test on cuda and rocm platform",
 )
 def test_qk_norm_rope_fusion(
-    eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
+    eps,
+    is_neox,
+    enable_rms_norm_custom_op,
+    enable_rope_custom_op,
+    dtype,
+    scattered_split,
 ):
     if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
         pytest.skip("fused_qk_norm_rope custom op not available")
@@ -161,13 +175,15 @@ def test_qk_norm_rope_fusion(
         is_neox=is_neox,
         vllm_config=vllm_config,
         dtype=dtype,
+        test_scattered_split=scattered_split,
     )

     noop_pass = NoOpEliminationPass(vllm_config)
+    coalesce_pass = SplitCoalescingPass(vllm_config)
     fusion_pass = QKNormRoPEFusionPass(vllm_config)
     cleanup_pass = PostCleanupPass(vllm_config)

-    backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
+    backend = TestBackend(noop_pass, coalesce_pass, fusion_pass, cleanup_pass)
     backend_baseline = TestBackend(noop_pass, cleanup_pass)

     qkv = torch.randn(T, model.q_size + 2 * model.kv_size)

File 3/5: new unit test for the split-coalescing pass

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm
from tests.compile.backend import TestBackend
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig


class SplitCoalescingModel(torch.nn.Module):
    """Model with 3 separate split_with_sizes calls on the same input,
    simulating the B200+FP8 graph where CSE fails to merge them."""

    def __init__(self, q_size: int, kv_size: int) -> None:
        super().__init__()
        self.q_size = q_size
        self.kv_size = kv_size

    def forward(self, qkv: torch.Tensor):
        q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        _, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        _, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        return q + 1, k + 2, v + 3


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_split_coalescing(dtype):
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

    q_size, kv_size = 2048, 512

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config=PassConfig(),
        )
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        coalesce_pass = SplitCoalescingPass(vllm_config)

        backend = TestBackend(coalesce_pass)
        model = SplitCoalescingModel(q_size, kv_size)

        T = 5
        qkv = torch.randn(T, q_size + 2 * kv_size)
        torch._dynamo.mark_dynamic(qkv, 0)

        result_eager = model(qkv)
        model_compiled = torch.compile(model, backend=backend)
        result_compiled = model_compiled(qkv)

        ATOL, RTOL = (2e-3, 2e-3)
        for eager, compiled in zip(result_eager, result_compiled):
            torch.testing.assert_close(eager, compiled, atol=ATOL, rtol=RTOL)

        assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1
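The final assertion is the heart of the test: after compilation with the pass attached, exactly one aten.split_with_sizes node should survive. A rough sketch of the kind of count backend.op_count performs, assuming TestBackend exposes the post-pass fx.GraphModule (the attribute spelling below is hypothetical):

    import torch
    from torch import fx

    def count_split_nodes(gm: fx.GraphModule) -> int:
        # One surviving aten.split_with_sizes node means the three
        # scattered splits in SplitCoalescingModel.forward were coalesced.
        return sum(
            1
            for node in gm.graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.split_with_sizes.default
        )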

File 4/5: post-grad pass manager registration

@@ -29,6 +29,7 @@ if current_platform.is_cuda_alike():
     from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
     from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
     from .fusion.sequence_parallelism import SequenceParallelismPass
+    from .utility.split_coalescing import SplitCoalescingPass

 if current_platform.is_cuda():
     from .fusion.allreduce_rms_fusion import AllReduceFusionPass
@@ -139,6 +140,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
             self.passes += [AttnFusionPass(config)]

         if self.pass_config.enable_qk_norm_rope_fusion:
+            self.passes += [SplitCoalescingPass(config)]
             self.passes += [QKNormRoPEFusionPass(config)]

         # needs a functional graph
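Ordering is the point of this hunk: SplitCoalescingPass must run before QKNormRoPEFusionPass so the fusion's pattern matcher sees one canonical split node. For completeness, a sketch of enabling the fusion from user configuration, assuming PassConfig accepts the flag from this hunk as a constructor argument:

    from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig

    # Turning on QK Norm+RoPE fusion now also registers the coalescing
    # pass ahead of it in the post-grad pass manager.
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config=PassConfig(enable_qk_norm_rope_fusion=True),
        )
    )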

File 5/5: the new pass, vllm/compilation/passes/utility/split_coalescing.py

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
input tensor with the same split sizes.

On certain hardware/dtype combinations (e.g. B200 + FP8), the Inductor
graph may contain multiple ``split_with_sizes`` calls on the same tensor
that CSE fails to merge. This pass detects and replaces the duplicates
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
see a single split node with all users attached.

See also:
- vLLM #33295 (original issue)
- PyTorch #174472 (upstream CSE gap)
"""

import operator

import torch
from torch import fx

from vllm.logger import init_logger

from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class SplitCoalescingPass(VllmInductorPass):
    """Replace duplicate ``split_with_sizes`` nodes with a single canonical
    node when they share the same input tensor and split sizes."""

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        count = 0
        # Map from input tensor node -> list of split nodes seen so far.
        split_nodes: dict[fx.Node, list[fx.Node]] = {}
        for node in graph.nodes:
            if not is_func(node, torch.ops.aten.split_with_sizes.default):
                continue
            if not all(is_func(user, operator.getitem) for user in node.users):
                continue

            arg_node, split_sizes = node.args[:2]
            if arg_node not in split_nodes:
                split_nodes[arg_node] = [node]
                continue

            # Find existing node with same split_sizes
            canonical = next(
                (
                    n
                    for n in split_nodes[arg_node]
                    if list(n.args[1]) == list(split_sizes)
                ),
                None,
            )
            if canonical is not None:
                node.replace_all_uses_with(canonical)
                graph.erase_node(node)
                count += 1
            else:
                split_nodes[arg_node].append(node)

        logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
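To make the mechanics concrete, here is a small standalone sketch (not part of the commit) that hand-builds an FX graph with two duplicate splits and runs the pass over it; vllm_config is assumed to be constructed as in the tests above:

    import operator

    import torch
    from torch import fx

    g = fx.Graph()
    qkv = g.placeholder("qkv")
    sizes = [4, 2, 2]
    # Two identical splits of the same tensor, each consumed only through
    # getitem, mimicking what CSE leaves behind on B200+FP8.
    s1 = g.call_function(torch.ops.aten.split_with_sizes.default, (qkv, sizes, -1))
    q = g.call_function(operator.getitem, (s1, 0))
    s2 = g.call_function(torch.ops.aten.split_with_sizes.default, (qkv, sizes, -1))
    k = g.call_function(operator.getitem, (s2, 1))
    g.output((q, k))

    SplitCoalescingPass(vllm_config)(g)
    # The getitem for k is rehomed onto s1 and s2 is erased, leaving a
    # single split_with_sizes node with both users attached.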