[Bugfix] Fix QK Norm+RoPE fusion pattern matching on B200+FP8 (#33967)
Signed-off-by: Ikenna <ikennachifo@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -101,9 +101,7 @@ qwen3_a3b_fp8 = ModelFusionInfo(
|
|||||||
model_name="Qwen/Qwen3-30B-A3B-FP8",
|
model_name="Qwen/Qwen3-30B-A3B-FP8",
|
||||||
matches=lambda n_layers: Matches(
|
matches=lambda n_layers: Matches(
|
||||||
rms_quant_fusion=n_layers,
|
rms_quant_fusion=n_layers,
|
||||||
# TODO broken on Blackwell:
|
norm_rope_fusion=n_layers,
|
||||||
# https://github.com/vllm-project/vllm/issues/33295
|
|
||||||
norm_rope_fusion=0 if is_blackwell() else n_layers,
|
|
||||||
attn_quant_fusion=0, # attn + group quant not supported
|
attn_quant_fusion=0, # attn + group quant not supported
|
||||||
ar_rms_fusion=n_layers * 2 + 1,
|
ar_rms_fusion=n_layers * 2 + 1,
|
||||||
sequence_parallel=n_layers * 2 + 1,
|
sequence_parallel=n_layers * 2 + 1,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
|
|||||||
)
|
)
|
||||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||||
|
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CompilationConfig,
|
CompilationConfig,
|
||||||
CompilationMode,
|
CompilationMode,
|
||||||
@@ -45,6 +46,7 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
is_neox: bool,
|
is_neox: bool,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
test_scattered_split: bool = False,
|
||||||
prefix: str = "model.layers.0.self_attn.attn",
|
prefix: str = "model.layers.0.self_attn.attn",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -78,11 +80,17 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
is_neox_style=is_neox,
|
is_neox_style=is_neox,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
self.test_scattered_split = test_scattered_split
|
||||||
self.enable_rms_norm_custom_op = self.q_norm.enabled()
|
self.enable_rms_norm_custom_op = self.q_norm.enabled()
|
||||||
self.enable_rope_custom_op = self.rotary_emb.enabled()
|
self.enable_rope_custom_op = self.rotary_emb.enabled()
|
||||||
|
|
||||||
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
|
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
if self.test_scattered_split:
|
||||||
|
q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
_, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
_, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
else:
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||||
q_by_head = self.q_norm(q_by_head)
|
q_by_head = self.q_norm(q_by_head)
|
||||||
q = q_by_head.view(q.shape)
|
q = q_by_head.view(q.shape)
|
||||||
@@ -112,6 +120,7 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
return [FUSED_QK_ROPE_OP]
|
return [FUSED_QK_ROPE_OP]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("scattered_split", [True, False])
|
||||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||||
@pytest.mark.parametrize("is_neox", [True, False])
|
@pytest.mark.parametrize("is_neox", [True, False])
|
||||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||||
@@ -122,7 +131,12 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
reason="Only test on cuda and rocm platform",
|
reason="Only test on cuda and rocm platform",
|
||||||
)
|
)
|
||||||
def test_qk_norm_rope_fusion(
|
def test_qk_norm_rope_fusion(
|
||||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
eps,
|
||||||
|
is_neox,
|
||||||
|
enable_rms_norm_custom_op,
|
||||||
|
enable_rope_custom_op,
|
||||||
|
dtype,
|
||||||
|
scattered_split,
|
||||||
):
|
):
|
||||||
if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
|
if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
|
||||||
pytest.skip("fused_qk_norm_rope custom op not available")
|
pytest.skip("fused_qk_norm_rope custom op not available")
|
||||||
@@ -161,13 +175,15 @@ def test_qk_norm_rope_fusion(
|
|||||||
is_neox=is_neox,
|
is_neox=is_neox,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
test_scattered_split=scattered_split,
|
||||||
)
|
)
|
||||||
|
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
coalesce_pass = SplitCoalescingPass(vllm_config)
|
||||||
fusion_pass = QKNormRoPEFusionPass(vllm_config)
|
fusion_pass = QKNormRoPEFusionPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
backend = TestBackend(noop_pass, coalesce_pass, fusion_pass, cleanup_pass)
|
||||||
backend_baseline = TestBackend(noop_pass, cleanup_pass)
|
backend_baseline = TestBackend(noop_pass, cleanup_pass)
|
||||||
|
|
||||||
qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
|
qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
|
||||||
|
|||||||
62
tests/compile/passes/test_split_coalescing.py
Normal file
62
tests/compile/passes/test_split_coalescing.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm
|
||||||
|
from tests.compile.backend import TestBackend
|
||||||
|
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
|
||||||
|
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SplitCoalescingModel(torch.nn.Module):
|
||||||
|
"""Model with 3 separate split_with_sizes calls on the same input,
|
||||||
|
simulating the B200+FP8 graph where CSE fails to merge them."""
|
||||||
|
|
||||||
|
def __init__(self, q_size: int, kv_size: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.q_size = q_size
|
||||||
|
self.kv_size = kv_size
|
||||||
|
|
||||||
|
def forward(self, qkv: torch.Tensor):
|
||||||
|
q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
_, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
_, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
return q + 1, k + 2, v + 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
|
def test_split_coalescing(dtype):
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
q_size, kv_size = 2048, 512
|
||||||
|
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
|
pass_config=PassConfig(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
|
coalesce_pass = SplitCoalescingPass(vllm_config)
|
||||||
|
backend = TestBackend(coalesce_pass)
|
||||||
|
|
||||||
|
model = SplitCoalescingModel(q_size, kv_size)
|
||||||
|
|
||||||
|
T = 5
|
||||||
|
qkv = torch.randn(T, q_size + 2 * kv_size)
|
||||||
|
torch._dynamo.mark_dynamic(qkv, 0)
|
||||||
|
|
||||||
|
result_eager = model(qkv)
|
||||||
|
|
||||||
|
model_compiled = torch.compile(model, backend=backend)
|
||||||
|
result_compiled = model_compiled(qkv)
|
||||||
|
|
||||||
|
ATOL, RTOL = (2e-3, 2e-3)
|
||||||
|
for eager, compiled in zip(result_eager, result_compiled):
|
||||||
|
torch.testing.assert_close(eager, compiled, atol=ATOL, rtol=RTOL)
|
||||||
|
|
||||||
|
assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1
|
||||||
@@ -29,6 +29,7 @@ if current_platform.is_cuda_alike():
|
|||||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||||
|
from .utility.split_coalescing import SplitCoalescingPass
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||||
@@ -139,6 +140,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
|||||||
self.passes += [AttnFusionPass(config)]
|
self.passes += [AttnFusionPass(config)]
|
||||||
|
|
||||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||||
|
self.passes += [SplitCoalescingPass(config)]
|
||||||
self.passes += [QKNormRoPEFusionPass(config)]
|
self.passes += [QKNormRoPEFusionPass(config)]
|
||||||
|
|
||||||
# needs a functional graph
|
# needs a functional graph
|
||||||
|
|||||||
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
|
||||||
|
input tensor with the same split sizes.
|
||||||
|
|
||||||
|
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
|
||||||
|
graph may contain multiple ``split_with_sizes`` calls on the same tensor
|
||||||
|
that CSE fails to merge. This pass detects and replaces the duplicates
|
||||||
|
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
|
||||||
|
see a single split node with all users attached.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- vLLM #33295 (original issue)
|
||||||
|
- PyTorch #174472 (upstream CSE gap)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import operator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import fx
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from ..fx_utils import is_func
|
||||||
|
from ..vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SplitCoalescingPass(VllmInductorPass):
|
||||||
|
"""Replace duplicate ``split_with_sizes`` nodes with a single canonical
|
||||||
|
node when they share the same input tensor and split sizes."""
|
||||||
|
|
||||||
|
@VllmInductorPass.time_and_log
|
||||||
|
def __call__(self, graph: fx.Graph) -> None:
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Map from input tensor node -> list of split nodes seen so far.
|
||||||
|
split_nodes: dict[fx.Node, list[fx.Node]] = {}
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
if not is_func(node, torch.ops.aten.split_with_sizes.default):
|
||||||
|
continue
|
||||||
|
if not all(is_func(user, operator.getitem) for user in node.users):
|
||||||
|
continue
|
||||||
|
|
||||||
|
arg_node, split_sizes = node.args[:2]
|
||||||
|
|
||||||
|
if arg_node not in split_nodes:
|
||||||
|
split_nodes[arg_node] = [node]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find existing node with same split_sizes
|
||||||
|
canonical = next(
|
||||||
|
(
|
||||||
|
n
|
||||||
|
for n in split_nodes[arg_node]
|
||||||
|
if list(n.args[1]) == list(split_sizes)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if canonical is not None:
|
||||||
|
node.replace_all_uses_with(canonical)
|
||||||
|
graph.erase_node(node)
|
||||||
|
count += 1
|
||||||
|
else:
|
||||||
|
split_nodes[arg_node].append(node)
|
||||||
|
|
||||||
|
logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
|
||||||
Reference in New Issue
Block a user