From 906077181b21128576018604562e55fbc7f70c34 Mon Sep 17 00:00:00 2001 From: Ikenna Date: Fri, 6 Feb 2026 21:27:33 -0500 Subject: [PATCH] [Bugfix] Fix QK Norm+RoPE fusion pattern matching on B200+FP8 (#33967) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ikenna Co-authored-by: Luka Govedič --- tests/compile/fusions_e2e/models.py | 4 +- .../passes/test_qk_norm_rope_fusion.py | 22 +++++- tests/compile/passes/test_split_coalescing.py | 62 ++++++++++++++++ vllm/compilation/passes/pass_manager.py | 2 + .../passes/utility/split_coalescing.py | 70 +++++++++++++++++++ 5 files changed, 154 insertions(+), 6 deletions(-) create mode 100644 tests/compile/passes/test_split_coalescing.py create mode 100644 vllm/compilation/passes/utility/split_coalescing.py diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index ef9b6be25..f54f617c6 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -101,9 +101,7 @@ qwen3_a3b_fp8 = ModelFusionInfo( model_name="Qwen/Qwen3-30B-A3B-FP8", matches=lambda n_layers: Matches( rms_quant_fusion=n_layers, - # TODO broken on Blackwell: - # https://github.com/vllm-project/vllm/issues/33295 - norm_rope_fusion=0 if is_blackwell() else n_layers, + norm_rope_fusion=n_layers, attn_quant_fusion=0, # attn + group quant not supported ar_rms_fusion=n_layers * 2 + 1, sequence_parallel=n_layers * 2 + 1, diff --git a/tests/compile/passes/test_qk_norm_rope_fusion.py b/tests/compile/passes/test_qk_norm_rope_fusion.py index bb8bc043e..f9a86732c 100644 --- a/tests/compile/passes/test_qk_norm_rope_fusion.py +++ b/tests/compile/passes/test_qk_norm_rope_fusion.py @@ -16,6 +16,7 @@ from vllm.compilation.passes.fusion.qk_norm_rope_fusion import ( ) from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass +from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass from vllm.config import ( CompilationConfig, CompilationMode, @@ -45,6 +46,7 @@ class QKNormRoPETestModel(torch.nn.Module): is_neox: bool, vllm_config: VllmConfig, dtype: torch.dtype, + test_scattered_split: bool = False, prefix: str = "model.layers.0.self_attn.attn", ) -> None: super().__init__() @@ -78,11 +80,17 @@ class QKNormRoPETestModel(torch.nn.Module): is_neox_style=is_neox, dtype=self.dtype, ) + self.test_scattered_split = test_scattered_split self.enable_rms_norm_custom_op = self.q_norm.enabled() self.enable_rope_custom_op = self.rotary_emb.enabled() def forward(self, qkv: torch.Tensor, positions: torch.Tensor): - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.test_scattered_split: + q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + _, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + _, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) @@ -112,6 +120,7 @@ class QKNormRoPETestModel(torch.nn.Module): return [FUSED_QK_ROPE_OP] +@pytest.mark.parametrize("scattered_split", [True, False]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("is_neox", [True, False]) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @@ -122,7 +131,12 @@ 
class QKNormRoPETestModel(torch.nn.Module): reason="Only test on cuda and rocm platform", ) def test_qk_norm_rope_fusion( - eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype + eps, + is_neox, + enable_rms_norm_custom_op, + enable_rope_custom_op, + dtype, + scattered_split, ): if not hasattr(torch.ops._C, "fused_qk_norm_rope"): pytest.skip("fused_qk_norm_rope custom op not available") @@ -161,13 +175,15 @@ def test_qk_norm_rope_fusion( is_neox=is_neox, vllm_config=vllm_config, dtype=dtype, + test_scattered_split=scattered_split, ) noop_pass = NoOpEliminationPass(vllm_config) + coalesce_pass = SplitCoalescingPass(vllm_config) fusion_pass = QKNormRoPEFusionPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend = TestBackend(noop_pass, coalesce_pass, fusion_pass, cleanup_pass) backend_baseline = TestBackend(noop_pass, cleanup_pass) qkv = torch.randn(T, model.q_size + 2 * model.kv_size) diff --git a/tests/compile/passes/test_split_coalescing.py b/tests/compile/passes/test_split_coalescing.py new file mode 100644 index 000000000..a217a4af9 --- /dev/null +++ b/tests/compile/passes/test_split_coalescing.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm +from tests.compile.backend import TestBackend +from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass +from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig + + +class SplitCoalescingModel(torch.nn.Module): + """Model with 3 separate split_with_sizes calls on the same input, + simulating the B200+FP8 graph where CSE fails to merge them.""" + + def __init__(self, q_size: int, kv_size: int) -> None: + super().__init__() + self.q_size = q_size + self.kv_size = kv_size + + def forward(self, qkv: torch.Tensor): + q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + _, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + _, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return q + 1, k + 2, v + 3 + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_split_coalescing(dtype): + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(0) + + q_size, kv_size = 2048, 512 + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig(), + ) + ) + with vllm.config.set_current_vllm_config(vllm_config): + coalesce_pass = SplitCoalescingPass(vllm_config) + backend = TestBackend(coalesce_pass) + + model = SplitCoalescingModel(q_size, kv_size) + + T = 5 + qkv = torch.randn(T, q_size + 2 * kv_size) + torch._dynamo.mark_dynamic(qkv, 0) + + result_eager = model(qkv) + + model_compiled = torch.compile(model, backend=backend) + result_compiled = model_compiled(qkv) + + ATOL, RTOL = (2e-3, 2e-3) + for eager, compiled in zip(result_eager, result_compiled): + torch.testing.assert_close(eager, compiled, atol=ATOL, rtol=RTOL) + + assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1 diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 2fd74fcd4..d9d3cc30b 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -29,6 +29,7 @@ if current_platform.is_cuda_alike(): from .fusion.qk_norm_rope_fusion import 
QKNormRoPEFusionPass from .fusion.rms_quant_fusion import RMSNormQuantFusionPass from .fusion.sequence_parallelism import SequenceParallelismPass + from .utility.split_coalescing import SplitCoalescingPass if current_platform.is_cuda(): from .fusion.allreduce_rms_fusion import AllReduceFusionPass @@ -139,6 +140,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] self.passes += [AttnFusionPass(config)] if self.pass_config.enable_qk_norm_rope_fusion: + self.passes += [SplitCoalescingPass(config)] self.passes += [QKNormRoPEFusionPass(config)] # needs a functional graph diff --git a/vllm/compilation/passes/utility/split_coalescing.py b/vllm/compilation/passes/utility/split_coalescing.py new file mode 100644 index 000000000..6373eacfa --- /dev/null +++ b/vllm/compilation/passes/utility/split_coalescing.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Coalesce duplicate ``split_with_sizes`` nodes that operate on the same +input tensor with the same split sizes. + +On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor +graph may contain multiple ``split_with_sizes`` calls on the same tensor +that CSE fails to merge. This pass detects and replaces the duplicates +so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion) +see a single split node with all users attached. + +See also: + - vLLM #33295 (original issue) + - PyTorch #174472 (upstream CSE gap) +""" + +import operator + +import torch +from torch import fx + +from vllm.logger import init_logger + +from ..fx_utils import is_func +from ..vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class SplitCoalescingPass(VllmInductorPass): + """Replace duplicate ``split_with_sizes`` nodes with a single canonical + node when they share the same input tensor and split sizes.""" + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + count = 0 + + # Map from input tensor node -> list of split nodes seen so far. + split_nodes: dict[fx.Node, list[fx.Node]] = {} + + for node in graph.nodes: + if not is_func(node, torch.ops.aten.split_with_sizes.default): + continue + if not all(is_func(user, operator.getitem) for user in node.users): + continue + + arg_node, split_sizes = node.args[:2] + + if arg_node not in split_nodes: + split_nodes[arg_node] = [node] + continue + + # Find existing node with same split_sizes + canonical = next( + ( + n + for n in split_nodes[arg_node] + if list(n.args[1]) == list(split_sizes) + ), + None, + ) + if canonical is not None: + node.replace_all_uses_with(canonical) + graph.erase_node(node) + count += 1 + else: + split_nodes[arg_node].append(node) + + logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
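
For reference, a minimal standalone sketch (not part of the patch above) of the coalescing idea on a plain torch.fx graph. It is an illustration under stated assumptions: symbolic_trace records Tensor.split as a call_method node rather than the post-grad aten.split_with_sizes node the real pass matches, and the toy module, split sizes, and dedup key below are invented for the example. The sketch keys each split on its input node together with the full remaining args and kwargs (sizes and dim) before rewiring later duplicates onto the first occurrence.

# Illustrative sketch only; runs on CPU with stock PyTorch.
import torch
from torch import fx


class DuplicateSplits(torch.nn.Module):
    def forward(self, qkv: torch.Tensor):
        # Three identical splits of the same tensor, mirroring the graph
        # shape that defeats CSE in the B200+FP8 case.
        q = qkv.split([4, 2, 2], dim=-1)[0]
        k = qkv.split([4, 2, 2], dim=-1)[1]
        v = qkv.split([4, 2, 2], dim=-1)[2]
        return q + 1, k + 2, v + 3


gm = fx.symbolic_trace(DuplicateSplits())

# Key each split on (input node, remaining args, kwargs) and rewire
# duplicates onto the first occurrence, then erase them.
seen: dict[tuple, fx.Node] = {}
for node in list(gm.graph.nodes):
    if node.op == "call_method" and node.target == "split":
        key = (node.args[0], repr(node.args[1:]), repr(sorted(node.kwargs.items())))
        if key in seen:
            node.replace_all_uses_with(seen[key])
            gm.graph.erase_node(node)
        else:
            seen[key] = node
gm.graph.lint()
gm.recompile()

# A single split now feeds all three getitem users, and results are unchanged.
assert sum(n.target == "split" for n in gm.graph.nodes if n.op == "call_method") == 1
x = torch.randn(3, 8)
for got, want in zip(gm(x), DuplicateSplits()(x)):
    torch.testing.assert_close(got, want)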