# Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
# Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
# Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import functools
|
|
import operator
|
|
import time
|
|
from typing import ClassVar, Optional
|
|
|
|
import regex as re
|
|
import torch
|
|
from torch._dynamo.utils import lazy_format_graph_code
|
|
from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
|
PatternPrettyPrinter)
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.logger import init_logger
|
|
|
|
from .inductor_pass import InductorPass
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class VllmInductorPass(InductorPass):
    """
    Base class for inductor passes with access to the vLLM PassConfig.

    Bundles the utilities shared by all passes: wall-clock timing
    (begin/end_and_log), debug logging, and fx-graph dumping.
    """

    dump_prefix: ClassVar[Optional[int]] = None
    """Keep track of pass index for debug dump ordering."""

    def __init__(self, config: VllmConfig):
        """Extract the pass-relevant pieces of the full vLLM config."""
        self.pass_config = config.compilation_config.pass_config
        # model_config / device_config may be unset; fall back to None.
        self.model_dtype = (config.model_config.dtype
                            if config.model_config else None)
        self.device = (config.device_config.device
                       if config.device_config else None)
        self.pass_name = self.__class__.__name__

    @staticmethod
    def time_and_log(call_fn):
        """Decorator: wrap a pass body with timing and before/after dumps."""

        @functools.wraps(call_fn)
        def timed_call(self: "VllmInductorPass", graph: torch.fx.Graph):
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return timed_call

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        """Emit the graph code via torch's lazy graph-code formatter."""
        # dump_prefix orders the dumps of successive passes.
        prefix = VllmInductorPass.dump_prefix
        i_str = f".{prefix}" if prefix is not None else ""
        lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
                               graph.owning_module)

    def begin(self):
        """Record the start timestamp for this pass."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        """Record the end timestamp and log the pass duration."""
        self._end_time = time.perf_counter_ns()
        duration_ms = (self._end_time - self._start_time) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
|
|
|
|
|
class VllmPatternMatcherPass(VllmInductorPass):
    """
    A VllmInductorPass that uses the Inductor pattern matcher.
    Its main use is providing the dump_patterns utility that dumps the
    Inductor pattern matcher patterns into a file, which greatly aids debugging.

    TODO(luka) move more utilities to this pass.
    """
    matched_count: int = 0
    """The number of matched patterns in the pass."""

    # Matches the default repr of a torch OpOverload, capturing the op
    # qualname (group 1) and the overload name (group 2), e.g.
    # <OpOverload(op='aten.add', overload='Tensor')>.
    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")

    def _replace_op_overloads(self, string: str) -> str:
        """Replace <OpOverload(..., ...)> with nicer formulations"""
        # Rewrite every OpOverload repr into the equivalent
        # torch.ops.<op>.<overload> expression so the dump reads as code.
        return self._OP_OVERLOAD_PATTERN.sub(
            lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
            string,
        )

    def dump_patterns(self, config: VllmConfig,
                      pm_pass: PatternMatcherPass) -> None:
        """
        If debug dumping is enabled, dump the Inductor pattern-matcher patterns
        into the debug_dump_path folder next to the dumped fx graphs.

        This method does its best to print something that looks like Python code
        for easier debugging and potentially navigation. If any errors appear in
        the output, please add to this method.

        TODO(luka): use pattern object to manually produce pattern graph
        """
        debug_dump_path = config.compile_debug_dump_path()
        if not debug_dump_path:
            # Debug dumping disabled: nothing to do.
            return

        debug_dump_path.mkdir(parents=True, exist_ok=True)

        # NOTE(review): local import — presumably avoids an import cycle
        # with vllm.utils; confirm before hoisting to module level.
        from vllm.utils import unique_filepath
        file_path = unique_filepath(
            lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")

        with file_path.open("w") as f:
            # File header: provenance comment plus the imports the dumped
            # pseudo-code needs to look like valid Python.
            print(
                f'# This file was produced by VllmPatternMatcherPass.'
                f'dump_patterns for {self.pass_name}.\n'
                f'# It does its best to produce valid-Python-looking code but'
                f' please add to dump_patterns if there are any errors.\n\n'
                f'from torch._higher_order_ops.auto_functionalize import '
                f'auto_functionalized as auto_functionalized\n'
                f'from torch._inductor.pattern_matcher import *',
                file=f)

            # pm_pass.patterns maps a (namespace, target) node key to the
            # list of registered patterns for that op.
            for node, patterns in pm_pass.patterns.items():
                # fix the operator.getitem repr
                if node[1] == operator.getitem:
                    node_repr = f"({repr(node[0])}, operator.getitem)"
                else:
                    node_repr = repr(node)

                node_repr = self._replace_op_overloads(node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # reserve auto_functionalized ahead of time
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble pattern: a def line, the memoized
                    # subexpressions, then the return — indented 4 spaces
                    # so the dump parses as a Python function.
                    out_node = pp.pretty_print(pattern.pattern)
                    pattern_repr = "\n".join([f"def pattern_{i}():"] + [
                        f"{pp.memoized_objs_names[key]} = "
                        f"{pp.memoized_objs_pp[key]}"
                        for key in pp.memoized_objs_names
                    ] + [f"return {out_node}"]).replace("\n", "\n    ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)
|
|
|
|
|
|
class PrinterInductorPass(VllmInductorPass):
    """A pass that only dumps the graph under a configured stage name."""

    def __init__(self, name: str, config: VllmConfig):
        """Remember the stage name to use when dumping."""
        super().__init__(config)
        self.name = name

    def __call__(self, graph: torch.fx.Graph):
        """Dump the graph; applies no transformation to it."""
        self.dump_graph(graph, stage=self.name)
|