[Bugfix] Fix QK Norm+RoPE fusion pattern matching on B200+FP8 (#33967)
Signed-off-by: Ikenna <ikennachifo@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -101,9 +101,7 @@ qwen3_a3b_fp8 = ModelFusionInfo(
|
||||
model_name="Qwen/Qwen3-30B-A3B-FP8",
|
||||
matches=lambda n_layers: Matches(
|
||||
rms_quant_fusion=n_layers,
|
||||
# TODO broken on Blackwell:
|
||||
# https://github.com/vllm-project/vllm/issues/33295
|
||||
norm_rope_fusion=0 if is_blackwell() else n_layers,
|
||||
norm_rope_fusion=n_layers,
|
||||
attn_quant_fusion=0, # attn + group quant not supported
|
||||
ar_rms_fusion=n_layers * 2 + 1,
|
||||
sequence_parallel=n_layers * 2 + 1,
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
|
||||
)
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
@@ -45,6 +46,7 @@ class QKNormRoPETestModel(torch.nn.Module):
|
||||
is_neox: bool,
|
||||
vllm_config: VllmConfig,
|
||||
dtype: torch.dtype,
|
||||
test_scattered_split: bool = False,
|
||||
prefix: str = "model.layers.0.self_attn.attn",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -78,10 +80,16 @@ class QKNormRoPETestModel(torch.nn.Module):
|
||||
is_neox_style=is_neox,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.test_scattered_split = test_scattered_split
|
||||
self.enable_rms_norm_custom_op = self.q_norm.enabled()
|
||||
self.enable_rope_custom_op = self.rotary_emb.enabled()
|
||||
|
||||
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
|
||||
if self.test_scattered_split:
|
||||
q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
_, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
_, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
else:
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
@@ -112,6 +120,7 @@ class QKNormRoPETestModel(torch.nn.Module):
|
||||
return [FUSED_QK_ROPE_OP]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scattered_split", [True, False])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("is_neox", [True, False])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@@ -122,7 +131,12 @@ class QKNormRoPETestModel(torch.nn.Module):
|
||||
reason="Only test on cuda and rocm platform",
|
||||
)
|
||||
def test_qk_norm_rope_fusion(
|
||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
||||
eps,
|
||||
is_neox,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_rope_custom_op,
|
||||
dtype,
|
||||
scattered_split,
|
||||
):
|
||||
if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
|
||||
pytest.skip("fused_qk_norm_rope custom op not available")
|
||||
@@ -161,13 +175,15 @@ def test_qk_norm_rope_fusion(
|
||||
is_neox=is_neox,
|
||||
vllm_config=vllm_config,
|
||||
dtype=dtype,
|
||||
test_scattered_split=scattered_split,
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
coalesce_pass = SplitCoalescingPass(vllm_config)
|
||||
fusion_pass = QKNormRoPEFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend = TestBackend(noop_pass, coalesce_pass, fusion_pass, cleanup_pass)
|
||||
backend_baseline = TestBackend(noop_pass, cleanup_pass)
|
||||
|
||||
qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
|
||||
|
||||
62
tests/compile/passes/test_split_coalescing.py
Normal file
62
tests/compile/passes/test_split_coalescing.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
|
||||
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
|
||||
|
||||
|
||||
class SplitCoalescingModel(torch.nn.Module):
|
||||
"""Model with 3 separate split_with_sizes calls on the same input,
|
||||
simulating the B200+FP8 graph where CSE fails to merge them."""
|
||||
|
||||
def __init__(self, q_size: int, kv_size: int) -> None:
|
||||
super().__init__()
|
||||
self.q_size = q_size
|
||||
self.kv_size = kv_size
|
||||
|
||||
def forward(self, qkv: torch.Tensor):
|
||||
q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
_, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
_, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
return q + 1, k + 2, v + 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_split_coalescing(dtype):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
q_size, kv_size = 2048, 512
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
coalesce_pass = SplitCoalescingPass(vllm_config)
|
||||
backend = TestBackend(coalesce_pass)
|
||||
|
||||
model = SplitCoalescingModel(q_size, kv_size)
|
||||
|
||||
T = 5
|
||||
qkv = torch.randn(T, q_size + 2 * kv_size)
|
||||
torch._dynamo.mark_dynamic(qkv, 0)
|
||||
|
||||
result_eager = model(qkv)
|
||||
|
||||
model_compiled = torch.compile(model, backend=backend)
|
||||
result_compiled = model_compiled(qkv)
|
||||
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
for eager, compiled in zip(result_eager, result_compiled):
|
||||
torch.testing.assert_close(eager, compiled, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1
|
||||
@@ -29,6 +29,7 @@ if current_platform.is_cuda_alike():
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from .utility.split_coalescing import SplitCoalescingPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
@@ -139,6 +140,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
self.passes += [SplitCoalescingPass(config)]
|
||||
self.passes += [QKNormRoPEFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
|
||||
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
|
||||
input tensor with the same split sizes.
|
||||
|
||||
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
|
||||
graph may contain multiple ``split_with_sizes`` calls on the same tensor
|
||||
that CSE fails to merge. This pass detects and replaces the duplicates
|
||||
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
|
||||
see a single split node with all users attached.
|
||||
|
||||
See also:
|
||||
- vLLM #33295 (original issue)
|
||||
- PyTorch #174472 (upstream CSE gap)
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SplitCoalescingPass(VllmInductorPass):
|
||||
"""Replace duplicate ``split_with_sizes`` nodes with a single canonical
|
||||
node when they share the same input tensor and split sizes."""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
count = 0
|
||||
|
||||
# Map from input tensor node -> list of split nodes seen so far.
|
||||
split_nodes: dict[fx.Node, list[fx.Node]] = {}
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.split_with_sizes.default):
|
||||
continue
|
||||
if not all(is_func(user, operator.getitem) for user in node.users):
|
||||
continue
|
||||
|
||||
arg_node, split_sizes = node.args[:2]
|
||||
|
||||
if arg_node not in split_nodes:
|
||||
split_nodes[arg_node] = [node]
|
||||
continue
|
||||
|
||||
# Find existing node with same split_sizes
|
||||
canonical = next(
|
||||
(
|
||||
n
|
||||
for n in split_nodes[arg_node]
|
||||
if list(n.args[1]) == list(split_sizes)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if canonical is not None:
|
||||
node.replace_all_uses_with(canonical)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
else:
|
||||
split_nodes[arg_node].append(node)
|
||||
|
||||
logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
|
||||
Reference in New Issue
Block a user