[Bugfix] Fix QK Norm+RoPE fusion pattern matching on B200+FP8 (#33967)
Signed-off-by: Ikenna <ikennachifo@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -29,6 +29,7 @@ if current_platform.is_cuda_alike():
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from .utility.split_coalescing import SplitCoalescingPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
@@ -139,6 +140,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
self.passes += [SplitCoalescingPass(config)]
|
||||
self.passes += [QKNormRoPEFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
|
||||
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
|
||||
input tensor with the same split sizes.
|
||||
|
||||
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
|
||||
graph may contain multiple ``split_with_sizes`` calls on the same tensor
|
||||
that CSE fails to merge. This pass detects and replaces the duplicates
|
||||
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
|
||||
see a single split node with all users attached.
|
||||
|
||||
See also:
|
||||
- vLLM #33295 (original issue)
|
||||
- PyTorch #174472 (upstream CSE gap)
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SplitCoalescingPass(VllmInductorPass):
|
||||
"""Replace duplicate ``split_with_sizes`` nodes with a single canonical
|
||||
node when they share the same input tensor and split sizes."""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
count = 0
|
||||
|
||||
# Map from input tensor node -> list of split nodes seen so far.
|
||||
split_nodes: dict[fx.Node, list[fx.Node]] = {}
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.split_with_sizes.default):
|
||||
continue
|
||||
if not all(is_func(user, operator.getitem) for user in node.users):
|
||||
continue
|
||||
|
||||
arg_node, split_sizes = node.args[:2]
|
||||
|
||||
if arg_node not in split_nodes:
|
||||
split_nodes[arg_node] = [node]
|
||||
continue
|
||||
|
||||
# Find existing node with same split_sizes
|
||||
canonical = next(
|
||||
(
|
||||
n
|
||||
for n in split_nodes[arg_node]
|
||||
if list(n.args[1]) == list(split_sizes)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if canonical is not None:
|
||||
node.replace_all_uses_with(canonical)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
else:
|
||||
split_nodes[arg_node].append(node)
|
||||
|
||||
logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
|
||||
Reference in New Issue
Block a user