[torch.compile] Cleanup compilation tests and custom passes, add debug utils, fix DCE bug (#23091), fix test (#24376), and prep for custom op matching (#24604) (#24542)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -16,10 +16,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import find_getitem_maybe
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .multi_output_match import MultiOutputMatch
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -50,8 +48,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[
|
||||
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
|
||||
@@ -80,68 +77,6 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
}
|
||||
|
||||
|
||||
class QuantMultiOutputMatch(MultiOutputMatch):
    """
    A multi-output match that remembers which quant op was matched and
    which fused op should take its place.
    """

    def __init__(self, match: pm.Match, quant_op, fused_op):
        """
        :param match: the pattern matcher match, forwarded to the base class
        :param quant_op: the in-place quant op (must be an OpOverload)
        :param fused_op: the in-place fused quant op (must be an OpOverload)
        """
        super().__init__(match)
        assert isinstance(quant_op, OpOverload)
        assert isinstance(fused_op, OpOverload)
        # Both ops are the in-place variants (quant op, fused quant op).
        self.QUANT_OP, self.FUSED_OP = quant_op, fused_op

    def insert_fused_node(self, fused_return_mapping: dict[int,
                                                           tuple[fx.Node,
                                                                 int]],
                          **kwargs):
        """
        Insert an auto-functionalized node for FUSED_OP, fill in its meta
        value, and redirect users of the unfused nodes to the new node.

        :param fused_return_mapping: maps each getitem index of the fused
            node's result to a pair (old node, getitem index on that node).
        :param kwargs: passed through unchanged to the auto_fn node

        Example:
        To turn this graph:
            _, x1, x2 = auto_fn(op1)
            _, y1, y2 = auto_fn(op2)

        into
            _, x1, y2, x2 = auto_fn(FUSED_OP)

        the call is:
            insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}

        Element 0 of an auto-functionalized in-place op is always None,
        which is why the remaining outputs look 1-indexed.
        """
        new_fused = self.insert_auto_fn(self.FUSED_OP, kwargs)
        out_indices = fused_return_mapping.keys()
        new_getitems = self.insert_getitems(new_fused, out_indices)

        # One mutable slot per position of the fused result tuple;
        # positions that are not in the mapping keep None.
        fused_vals = [None for _ in range(max(out_indices) + 1)]

        # Walk the outputs of the fused node alongside their getitem nodes
        for (new_idx, (src_node, src_idx)), new_getitem in zip(
                fused_return_mapping.items(), new_getitems):
            # An output that was never used may have no getitem node at all
            src_getitem = find_getitem_maybe(src_node, src_idx)
            if src_getitem is not None:
                # Point the users of the matched getitem at the new one.
                # DCE at the end of the pass removes the old nodes.
                src_getitem.replace_all_uses_with(new_getitem)
                new_getitem.meta["val"] = src_getitem.meta["val"]

            # The old node's meta value is available even when the
            # getitem node is missing, so always read it from there.
            fused_vals[new_idx] = src_node.meta["val"][src_idx]

        # Attach the assembled meta value to the new fused node
        new_fused.meta["val"] = tuple(fused_vals)
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
|
||||
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
@@ -224,8 +159,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
record_match: Callable[[MultiOutputMatch], bool]):
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
@@ -271,36 +205,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=lambda m: record_match(
|
||||
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
|
||||
|
||||
class Match(QuantMultiOutputMatch):
    """Manual replacement of a matched RMS_ADD_OP + quant op pair."""

    def process(self):
        """
        Replace the matched nodes with a single fused node.

        A new auto_functionalized node for the fused op is inserted after
        the match, together with getitem nodes for the result and the
        residual; the node returns a tuple (None, result, residual):
            at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...)  # noqa
            result_node_new = at[1]
            residual_node_new = at[2]
        """
        # Locate the matched nodes whose users have to be rebound
        norm_node = self.find_auto_fn(RMS_ADD_OP)
        q_node = self.find_auto_fn(self.QUANT_OP)

        assert len(norm_node.users) == 2
        assert len(q_node.users) == 1

        with self.inserting_after_match():
            fused_kwargs = self.match.kwargs.copy()
            # Scalars cannot be pattern inputs, so epsilon is absent from
            # the match kwargs and is taken from the rms node instead.
            eps = norm_node.kwargs["epsilon"]

            # Index 0 is always None, so the mapping starts at 1
            self.insert_fused_node({
                1: (q_node, 1),
                2: (norm_node, 2)
            },
                                   **fused_kwargs,
                                   epsilon=eps)
|
||||
)
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
@@ -317,8 +222,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
record_match: Callable[[MultiOutputMatch], bool]):
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
@@ -366,39 +270,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=lambda m: record_match(
|
||||
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
|
||||
|
||||
class Match(QuantMultiOutputMatch):
    """Manual replacement of a matched RMS_OP + quant op pair."""

    def process(self):
        """
        Replace the matched nodes with a single fused node.

        A new auto_functionalized node for the fused op is inserted after
        the match, together with getitem nodes for the result and the
        scale; the node returns a tuple (None, result, scale):
            at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...)  # noqa
            result_node_new = at[1]
            scale_node_new = at[2]
        """
        # Locate the matched nodes whose users have to be rebound
        norm_node = self.find_auto_fn(RMS_OP)
        q_node = self.find_auto_fn(self.QUANT_OP)

        assert len(norm_node.users) == 1
        assert len(q_node.users) == 2

        with self.inserting_after_match():
            fused_kwargs = self.match.kwargs.copy()
            # The fused op does not take result_rms
            fused_kwargs.pop("result_rms")

            # Scalars cannot be pattern inputs, so epsilon is absent from
            # the match kwargs and is taken from the rms node instead.
            eps = norm_node.kwargs["epsilon"]

            # Both fused outputs come from the quant node
            outputs = {1: (q_node, 1), 2: (q_node, 2)}
            self.insert_fused_node(
                outputs,
                epsilon=eps,
                scale_ub=None,  # required by the op, unused here
                residual=None,  # required by the op, unused here
                **fused_kwargs)
|
||||
)
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
@@ -415,8 +287,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
record_match: Callable[[MultiOutputMatch], bool]):
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
@@ -464,137 +335,49 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=lambda m: record_match(
|
||||
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
|
||||
|
||||
class Match(QuantMultiOutputMatch):
    """Manual replacement of a matched RMS_ADD_OP + quant op pair."""

    def process(self):
        """
        Replace the matched nodes with a single fused node.

        A new auto_functionalized node for the fused op is inserted after
        the match, together with getitem nodes for result, scale and
        residual; the node returns a tuple (None, result, scale, residual):
            at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...)  # noqa
            result_node_new = at[1]
            scale_node_new = at[2]
            residual_node_new = at[3]
        """
        # Locate the matched nodes whose users have to be rebound
        norm_node = self.find_auto_fn(RMS_ADD_OP)
        q_node = self.find_auto_fn(self.QUANT_OP)

        assert len(norm_node.users) == 2
        assert len(q_node.users) == 2

        with self.inserting_after_match():
            fused_kwargs = self.match.kwargs.copy()
            # Scalars cannot be pattern inputs, so epsilon is absent from
            # the match kwargs and is taken from the rms node instead.
            eps = norm_node.kwargs["epsilon"]

            outputs = {
                1: (q_node, 1),  # result
                2: (q_node, 2),  # scale
                3: (norm_node, 2),  # residual
            }
            self.insert_fused_node(
                outputs,
                epsilon=eps,
                scale_ub=None,  # required by the op, unused here
                **fused_kwargs)
|
||||
)
|
||||
|
||||
|
||||
class FusionPass(VllmInductorPass):
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
It also manually processes multi-output matches, as those are broken in
|
||||
the torch pattern matcher.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
_instance: 'Optional[FusionPass]' = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls, config: VllmConfig):
|
||||
"""
|
||||
Get the singleton instance of the FusionPass.
|
||||
If the instance exists, the config is updated but
|
||||
initialization is not repeated.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = FusionPass(config)
|
||||
else:
|
||||
cls._instance.pass_config = config.compilation_config.pass_config
|
||||
return cls._instance
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
assert self.__class__._instance is None, \
|
||||
"FusionPass singleton instance already exists"
|
||||
super().__init__(config)
|
||||
|
||||
self.matches: list[MultiOutputMatch] = []
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="fusion_pass")
|
||||
pass_name="rmsnorm_quant_fusion_pass")
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Matches for patterns below have 2 or more outputs,
|
||||
# so we need to process them manually (see process_matches)
|
||||
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
self.patterns)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
RMSNormDynamicQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
self.patterns)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
torch._inductor.pattern_matcher._seen_patterns.clear()
|
||||
|
||||
def record_match(self, match: MultiOutputMatch) -> bool:
|
||||
# Hijack the extra_check to record the match and
|
||||
# save it for post-processing.
|
||||
self.matches.append(match)
|
||||
|
||||
# Return False to prevent automatic replacement.
|
||||
return False
|
||||
|
||||
def process_matches(self, graph: fx.Graph):
|
||||
"""
|
||||
Manually process multi-output matches and replace them with fused nodes.
|
||||
See MultiOutputMatch for more details.
|
||||
"""
|
||||
for match in self.matches:
|
||||
match.process()
|
||||
|
||||
# Finally, remove matched nodes
|
||||
graph.eliminate_dead_code()
|
||||
assert all(node not in graph.nodes for match in self.matches
|
||||
for node in match.match.nodes)
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_fusion")
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", count)
|
||||
self.dump_graph(graph, "after_pattern_match")
|
||||
|
||||
# Manually process multi-output matches (and run DCE)
|
||||
self.process_matches(graph)
|
||||
logger.debug("Post-processed %s matches", len(self.matches))
|
||||
self.dump_graph(graph, "after_fusion")
|
||||
self.matches.clear()
|
||||
self.end_and_log()
|
||||
def uuid(self) -> Any:
|
||||
return self.hash_source(self, RMSNormQuantPattern,
|
||||
RMSNormStaticQuantPattern,
|
||||
RMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormStaticQuantPattern,
|
||||
FusedAddRMSNormDynamicQuantPattern)
|
||||
|
||||
Reference in New Issue
Block a user