diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index f860ce290..718a4a815 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -225,7 +225,7 @@ outputs = model.generate( ### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism) -Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs. +Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnQuantFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs. Long term, we've added the ability to partition the graph in Inductor instead of right after Dynamo. It can be enabled with `CompilationConfig.use_inductor_graph_partition=True` but is currently experimental and only available with `torch>=2.9`. This also increases compilation time as it has to compile the whole graph and cannot reuse piecewise compilation artifacts. Once vLLM supports 2.9, we plan to make this the default approach as it will also speed up piecewise cudagraph capture. diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index 7cd2acdf5..ca67d90d2 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging +from collections import defaultdict import pytest import regex as re @@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints ) + # Fetch match table from each worker via RPC and sum across workers. + worker_tables = llm.llm_engine.engine_core.collective_rpc( + "get_compilation_match_table" + ) + combined: defaultdict[str, int] = defaultdict(int) + for table in worker_tables: + for k, v in table.items(): + combined[k] += v + return dict(combined) + @pytest.fixture def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): @@ -113,7 +124,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ) with caplog_mp_spawn(logging.DEBUG) as log_holder: - run_model(full_compilation_config, model_name, **model_kwargs) + match_table = run_model(full_compilation_config, model_name, **model_kwargs) num_compile_ranges = len(full_compilation_config.get_compile_ranges()) assert num_compile_ranges in [1, 2, 3] @@ -155,11 +166,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): else: num_ranges_activated = num_compile_ranges + # TODO: Remove log counting in unit tests + # once all matchers implement VllmFusionPatternMatcherPass n_expected = tp_size * num_ranges_activated - assert len(log_matches) == n_expected, ( - f"Could not find {n_expected} {match_name} " - f"(found {len(log_matches)}) in:\n {log_holder.text}" - ) + if match_name != "attn_quant_fusion": + assert len(log_matches) == n_expected, ( + f"Could not find {n_expected} {match_name} " + f"(found {len(log_matches)}) in:\n {log_holder.text}" + ) expected_matches = getattr(matches, match_name) @@ -215,6 +229,13 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): f"{tp_size * (num_ranges_activated - 1)} large-range " f"entries (SP took precedence), found: {log_matches}" ) + + elif match_name == "attn_quant_fusion": + actual_match = match_table.get(match_name, 0) + assert actual_match == expected_matches * n_expected, ( + f"Could not find {expected_matches * n_expected} " + f"{match_name} (found {actual_match})." + ) else: expected_matches_list = [expected_matches] * n_expected assert sorted(log_matches) == expected_matches_list, ( diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py index 94014ca01..2c5ac7b0b 100644 --- a/tests/compile/passes/test_fusion_attn.py +++ b/tests/compile/passes/test_fusion_attn.py @@ -9,7 +9,10 @@ from tests.compile.backend import LazyInitPass, TestBackend from tests.utils import TestFP8Layer, flat_product from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass +from vllm.compilation.passes.fusion.attn_quant_fusion import ( + ATTN_OP, + AttnQuantFusionPass, +) from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS from vllm.compilation.passes.fx_utils import find_op_nodes from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass @@ -384,7 +387,7 @@ def test_attention_quant_pattern( # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) - attn_pass = LazyInitPass(AttnFusionPass, vllm_config) + attn_pass = LazyInitPass(AttnQuantFusionPass, vllm_config) cleanup_pass = PostCleanupPass(vllm_config) test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) @@ -434,7 +437,7 @@ def test_attention_quant_pattern( # Only output quant ops are fused into attention. test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic) - # access the underlying `AttnFusionPass` on the `LazyInitPass` + # access the underlying `AttnQuantFusionPass` on the `LazyInitPass` assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) # Check attention ops in the graph before and after fusion diff --git a/vllm/compilation/passes/fusion/attn_quant_fusion.py b/vllm/compilation/passes/fusion/attn_quant_fusion.py index 0e1b846af..98d2be387 100644 --- a/vllm/compilation/passes/fusion/attn_quant_fusion.py +++ b/vllm/compilation/passes/fusion/attn_quant_fusion.py @@ -1,15 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod + from collections.abc import Callable -from typing import Any, ParamSpec import torch -import torch._inductor.pattern_matcher as pm -from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger @@ -22,14 +18,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -from ..fx_utils import is_func -from ..inductor_pass import enable_fake_mode -from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass +from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement from .matcher_utils import MatcherQuantFP8 -from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .rms_quant_fusion import QUANT_OPS logger = init_logger(__name__) -P = ParamSpec("P") + FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -37,83 +31,10 @@ ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default -class AttentionQuantPattern(ABC): - """ - The base class for Attn+Quant fusions. - Should not be used directly. - """ - - def __init__( - self, - layer: Attention, - quant_key: QuantKey, - dtype: torch.dtype, - ) -> None: - self.layer = layer - self.layer_name = layer.layer_name - self.num_heads = layer.num_heads - self.head_size = layer.head_size - self.quant_key = quant_key - self.quant_dtype = quant_key.dtype - self.dtype = dtype - - assert self.quant_key in QUANT_OPS, ( - f"unsupported quantization scheme {self.quant_key}" - ) - self.QUANT_OP = QUANT_OPS[self.quant_key] - - def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor: - kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs} - return torch.empty(*args, **kwargs) - - def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor: - kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} - return torch.empty(*args, **kwargs) - - @staticmethod - def wrap_trace_fn( - trace_fn: Callable[P, fx.GraphModule], - *process_fx_fns: Callable[[fx.GraphModule], None], - ) -> Callable[P, fx.GraphModule]: - def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule: - gm = trace_fn(*args, **kwargs) - for process_fx in process_fx_fns: - process_fx(gm) - - return gm - - return wrapped - - @staticmethod - def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None: - from torch._inductor.fx_passes.post_grad import view_to_reshape - - view_to_reshape(gm) - - @staticmethod - def remove_noop_permutes(gm: torch.fx.GraphModule) -> None: - for node in gm.graph.nodes: - if not is_func(node, torch.ops.aten.permute.default): - continue - - dims = node.args[1] - if any(dim != i for i, dim in enumerate(dims)): - continue - - # this is now an identity op, remove - node.replace_all_uses_with(node.args[0]) - gm.graph.erase_node(node) - - def register_if_supported(self, pm_pass: PatternMatcherPass) -> None: - if self.layer.impl.fused_output_quant_supported(self.quant_key): - self._register(pm_pass) - - @abstractmethod - def _register(self, pm_pass: PatternMatcherPass) -> None: - raise NotImplementedError +_FP8_QUANT_KEY = QuantKey(dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=True) -class AttentionFp8StaticQuantPattern(AttentionQuantPattern): +class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): """ Fusion for Attention+Fp8StaticQuant. @@ -123,20 +44,16 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): will be passed into Attention op as the `output_scale` argument. """ - def __init__( - self, - layer: Attention, - dtype: torch.dtype, - symmetric: bool = True, - ) -> None: - quant_key = QuantKey( - dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric - ) - super().__init__(layer, quant_key, dtype) - self.quant_matcher = MatcherQuantFP8(quant_key) + def __init__(self, layer: Attention, dtype: torch.dtype): + self._layer_name = layer.layer_name + self._num_heads = layer.num_heads + self._head_size = layer.head_size + self._dtype = dtype + self._quant_matcher = MatcherQuantFP8(_FP8_QUANT_KEY) - def _register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( + @property + def pattern(self) -> Callable[..., torch.Tensor]: + def _pattern( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -150,18 +67,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): key=k, value=v, output=output_attn, - layer_name=self.layer_name, + layer_name=self._layer_name, output_scale=None, output_block_scale=None, kv_cache_dummy_dep=kv_cache_dummy_dep, ) attn_out_view = RESHAPE_OP( - at1[1], [q.shape[0], self.num_heads * self.head_size] + at1[1], [q.shape[0], self._num_heads * self._head_size] ) + return self._quant_matcher(attn_out_view, scale)[0] - return self.quant_matcher(attn_out_view, scale)[0] + return _pattern - def replacement( + @property + def replacement(self) -> Callable[..., torch.Tensor]: + def _replacement( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -169,10 +89,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): scale: torch.Tensor, kv_cache_dummy_dep: torch.Tensor, ) -> torch.Tensor: - # attn output in quant_dtype output_attn = torch.empty( - [q.shape[0], self.num_heads, self.head_size], - dtype=self.quant_dtype, + [q.shape[0], self._num_heads, self._head_size], + dtype=FP8_DTYPE, device=q.device, ) at1 = auto_functionalized( @@ -181,36 +100,32 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): key=k, value=v, output=output_attn, - layer_name=self.layer_name, + layer_name=self._layer_name, output_scale=scale, output_block_scale=None, kv_cache_dummy_dep=kv_cache_dummy_dep, ) - return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) + return RESHAPE_OP(at1[1], [-1, self._num_heads * self._head_size]) - inputs = [ - self.empty(5, self.num_heads, self.head_size), # q - self.empty(5, self.num_heads, self.head_size), # k - self.empty(5, self.num_heads, self.head_size), # v - self.empty(5, self.num_heads, self.head_size), # attn_output - empty_fp32(1, 1), # scale - self.empty(0), # kv_cache_dummy_dep + return _replacement + + def get_inputs(self): + dtype = self._dtype + num_heads = self._num_heads + head_size = self._head_size + return [ + self.empty(5, num_heads, head_size, dtype=dtype), # q + self.empty(5, num_heads, head_size, dtype=dtype), # k + self.empty(5, num_heads, head_size, dtype=dtype), # v + self.empty(5, num_heads, head_size, dtype=dtype), # attn_output + self.empty_fp32(1, 1), # scale + self.empty(0, dtype=dtype), # kv_cache_dummy_dep ] - pm.register_replacement( - pattern, - replacement, - inputs, - AttentionQuantPattern.wrap_trace_fn( - pm.fwd_only, - AttentionQuantPattern.fx_view_to_reshape, - AttentionQuantPattern.remove_noop_permutes, - ), - pm_pass, - ) - -class AttentionNvfp4QuantPattern(AttentionQuantPattern): +class AttnNvfp4QuantPattern( + VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]] +): """ Fusion for Attention+Nvfp4Quant. @@ -220,11 +135,16 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): will be passed into Attention op as the `output_scale` argument. """ - def __init__(self, layer: Attention, dtype: torch.dtype) -> None: - super().__init__(layer, kNvfp4Dynamic, dtype) + def __init__(self, layer: Attention, dtype: torch.dtype): + self._layer_name = layer.layer_name + self._num_heads = layer.num_heads + self._head_size = layer.head_size + self._dtype = dtype + self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic] - def _register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( + @property + def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: + def _pattern( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -240,16 +160,16 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): key=k, value=v, output=output_attn, - layer_name=self.layer_name, + layer_name=self._layer_name, output_scale=None, output_block_scale=None, kv_cache_dummy_dep=kv_cache_dummy_dep, ) attn_out_view = RESHAPE_OP( - at1[1], [q.shape[0], self.num_heads * self.head_size] + at1[1], [q.shape[0], self._num_heads * self._head_size] ) at2 = auto_functionalized( - self.QUANT_OP, + self._QUANT_OP, input=attn_out_view, input_scale=input_scale, is_sf_swizzled_layout=True, @@ -259,23 +179,25 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) return at2[1], output_scale_view - def replacement( + return _pattern + + @property + def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: + def _replacement( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, + _output_quant: torch.Tensor, output_scale: torch.Tensor, input_scale: torch.Tensor, kv_cache_dummy_dep: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - # attention output in quant_dtype output_attn = torch.empty( - [q.shape[0], self.num_heads, self.head_size // 2], - dtype=self.quant_dtype, + [q.shape[0], self._num_heads, self._head_size // 2], + dtype=FP4_DTYPE, device=q.device, ) - # attention output block scale output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) at2 = auto_functionalized( ATTN_OP, @@ -283,41 +205,35 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): key=k, value=v, output=output_attn, - layer_name=self.layer_name, + layer_name=self._layer_name, output_scale=input_scale, output_block_scale=output_scale_view, kv_cache_dummy_dep=kv_cache_dummy_dep, ) - output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2]) + output = RESHAPE_OP(at2[1], [-1, self._num_heads * self._head_size // 2]) return output, at2[2] - inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads, self.head_size), # output_attn - self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant - empty_i32( - 128, round_up(self.num_heads * self.head_size // 16, 4) + return _replacement + + def get_inputs(self): + dtype = self._dtype + num_heads = self._num_heads + head_size = self._head_size + return [ + self.empty_bf16(5, num_heads, head_size), # q + self.empty_bf16(5, num_heads, head_size), # k + self.empty_bf16(5, num_heads, head_size), # v + self.empty_bf16(5, num_heads, head_size), # output_attn + self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE), # output_quant + self.empty_i32( + 128, round_up(num_heads * head_size // 16, 4) ), # output_scale - empty_fp32(1, 1), # input_scale - self.empty(0), # kv_cache_dummy_dep + self.empty_fp32(1, 1), # input_scale + self.empty(0, dtype=dtype), # kv_cache_dummy_dep ] - pm.register_replacement( - pattern, - replacement, - inputs, - AttentionQuantPattern.wrap_trace_fn( - pm.fwd_only, - AttentionQuantPattern.fx_view_to_reshape, - AttentionQuantPattern.remove_noop_permutes, - ), - pm_pass, - ) - -class AttnFusionPass(VllmPatternMatcherPass): +class AttnQuantFusionPass(VllmFusionPatternMatcherPass): """ This pass fuses post-attention quantization onto attention if supported. @@ -330,43 +246,26 @@ class AttnFusionPass(VllmPatternMatcherPass): support are attention kernels, which need to support fusing output quant. """ - @enable_fake_mode def __init__(self, config: VllmConfig) -> None: - super().__init__(config) + super().__init__(config, "attn_quant_fusion") - self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") + dtype = config.model_config.dtype + layers = list(get_layers_from_vllm_config(config, Attention).values()) - attn_layers = get_layers_from_vllm_config(config, Attention) - for layer_name, layer in attn_layers.items(): - pattern_fp8 = AttentionFp8StaticQuantPattern( - layer, config.model_config.dtype - ) - pattern_fp8.register_if_supported(self.patterns) - - if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - pattern_nvfp4 = AttentionNvfp4QuantPattern( - layer, config.model_config.dtype - ) - pattern_nvfp4.register_if_supported(self.patterns) - - if len(attn_layers) == 0: + if len(layers) == 0: logger.warning( "Attention + quant fusion is enabled, but no attention layers " "were found in CompilationConfig.static_forward_context " "so no fusion patterns were registered." ) - self.dump_patterns(config, self.patterns) + for layer in layers: + if layer.impl.fused_output_quant_supported(_FP8_QUANT_KEY): + self.register(AttnFp8StaticQuantPattern(layer, dtype)) - @VllmInductorPass.time_and_log - def __call__(self, graph: torch.fx.graph.Graph) -> None: - self.matched_count = self.patterns.apply(graph) - logger.debug("Fused quant onto %s attention nodes", self.matched_count) + if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + for layer in layers: + if layer.impl.fused_output_quant_supported(kNvfp4Dynamic): + self.register(AttnNvfp4QuantPattern(layer, dtype)) - def uuid(self) -> str: - return VllmInductorPass.hash_source( - self, - AttentionQuantPattern, - AttentionFp8StaticQuantPattern, - AttentionNvfp4QuantPattern, - ) + self.dump_patterns(config, self.pm_pass) diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 70f86c8d2..5f75fc8db 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -14,7 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.system_utils import set_env_var -from .vllm_inductor_pass import VllmInductorPass +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass if rocm_aiter_ops.is_enabled(): from .fusion.rocm_aiter_fusion import ( @@ -25,7 +25,7 @@ if rocm_aiter_ops.is_enabled(): if current_platform.is_cuda_alike(): from .fusion.act_quant_fusion import ActivationQuantFusionPass - from .fusion.attn_quant_fusion import AttnFusionPass + from .fusion.attn_quant_fusion import AttnQuantFusionPass from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass from .fusion.rms_quant_fusion import RMSNormQuantFusionPass from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass @@ -108,6 +108,8 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] self.fix_functionalization(graph) VllmInductorPass.dump_prefix = None # Cleanup index + VllmPatternMatcherPass.log_match_summary() + def configure(self, config: VllmConfig) -> None: self.pass_config = config.compilation_config.pass_config @@ -144,7 +146,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] self.passes += [RopeKVCacheFusionPass(config)] if self.pass_config.fuse_attn_quant: - self.passes += [AttnFusionPass(config)] + self.passes += [AttnQuantFusionPass(config)] if self.pass_config.enable_qk_norm_rope_fusion: self.passes += [SplitCoalescingPass(config)] diff --git a/vllm/compilation/passes/vllm_inductor_pass.py b/vllm/compilation/passes/vllm_inductor_pass.py index b64c89288..4eac620d1 100644 --- a/vllm/compilation/passes/vllm_inductor_pass.py +++ b/vllm/compilation/passes/vllm_inductor_pass.py @@ -3,19 +3,24 @@ import functools import operator import time +from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass -from typing import ClassVar +from typing import Any, ClassVar, Generic, ParamSpec, TypeVar import regex as re import torch +import torch._inductor.pattern_matcher as pm +from torch import fx from torch._dynamo.utils import lazy_format_graph_code from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter from vllm.config import VllmConfig from vllm.logger import init_logger -from .inductor_pass import InductorPass +from .fx_utils import is_func +from .inductor_pass import InductorPass, enable_fake_mode logger = init_logger(__name__) @@ -79,18 +84,23 @@ class VllmInductorPass(InductorPass): logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) +def get_match_table() -> dict[str, int]: + """Return a snapshot of the match table.""" + return dict(VllmPatternMatcherPass.match_table) + + class VllmPatternMatcherPass(VllmInductorPass): """ A VllmInductorPass that uses the Inductor pattern matcher. - Its main use is providing the dump_patterns utility that dumps the - Inductor pattern matcher patterns into a file, which greatly aids debugging. - - TODO(luka) move more utilities to this pass. + Provides pattern registration with match counting, debug dumping, and logging. """ matched_count: int = 0 """The number of matched patterns in the pass.""" + match_table: ClassVar[defaultdict[str, int]] = defaultdict(int) + """Global table mapping pass name to its total match count.""" + _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile( r"" ) @@ -104,6 +114,11 @@ class VllmPatternMatcherPass(VllmInductorPass): ) ) + @classmethod + def log_match_summary(cls) -> None: + if cls.match_table: + logger.debug("fusion pass matches: %s", dict(cls.match_table)) + def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None: """ If debug dumping is enabled, dump the Inductor pattern-matcher patterns @@ -171,6 +186,124 @@ class VllmPatternMatcherPass(VllmInductorPass): print(f"{pattern_repr}\n", file=f) +P = ParamSpec("P") +R = TypeVar("R") + + +class VllmPatternReplacement(ABC, Generic[P, R]): + """ + A pattern/replacement pair for FX graph fusion. + + Implement the three abstract members below, then pass + instances to VllmFusionPatternMatcherPass.register(). The pass will + find every occurrence of `pattern` in the graph and substitute it + with `replacement`. + """ + + # TODO(Badr): bound methods work for pattern registration since + # PyTorch 2.10. Once vLLM requires torch>=2.11, replace these properties + # with plain methods and drop the closure indirection. + @property + @abstractmethod + def pattern(self) -> Callable[P, R]: + """Returns a closure defining the FX subgraph to search for.""" + ... + + @property + @abstractmethod + def replacement(self) -> Callable[P, R]: + """ + Returns a closure defining the FX subgraph to + substitute in place of each match. + """ + ... + + @abstractmethod + def get_inputs(self) -> list[torch.Tensor]: + """Example tensors used to trace pattern and replacement.""" + ... + + # Helpers for get_inputs: uninitialized tensors of common dtypes. + @staticmethod + def empty(*args, **kwargs) -> torch.Tensor: + return torch.empty(*args, device="cuda", **kwargs) + + @staticmethod + def empty_bf16(*args, **kwargs) -> torch.Tensor: + return torch.empty(*args, dtype=torch.bfloat16, device="cuda", **kwargs) + + @staticmethod + def empty_fp16(*args, **kwargs) -> torch.Tensor: + return torch.empty(*args, dtype=torch.float16, device="cuda", **kwargs) + + @staticmethod + def empty_fp32(*args, **kwargs) -> torch.Tensor: + return torch.empty(*args, dtype=torch.float32, device="cuda", **kwargs) + + @staticmethod + def empty_i32(*args, **kwargs) -> torch.Tensor: + return torch.empty(*args, dtype=torch.int32, device="cuda", **kwargs) + + +def _fx_view_to_reshape(gm: fx.GraphModule) -> None: + from torch._inductor.fx_passes.post_grad import view_to_reshape + + view_to_reshape(gm) + + +def _remove_noop_permutes(gm: fx.GraphModule) -> None: + for node in gm.graph.nodes: + if not is_func(node, torch.ops.aten.permute.default): + continue + dims = node.args[1] + if any(dim != i for i, dim in enumerate(dims)): + continue + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + + +class VllmFusionPatternMatcherPass(VllmPatternMatcherPass): + """ + A VllmPatternMatcherPass for passes that use VllmPatternReplacement objects. + Subclasses register patterns via self.register() in their own __init__. + """ + + def __init__(self, config: VllmConfig, pass_name: str) -> None: + super().__init__(config) + self.pass_name = pass_name + self.pm_pass = PatternMatcherPass(pass_name=pass_name) + self._pattern_replacements: list[VllmPatternReplacement] = [] + + @enable_fake_mode + def register(self, pr: VllmPatternReplacement) -> None: + pm.register_replacement( + pr.pattern, + pr.replacement, + pr.get_inputs(), + self._trace_fn, + self.pm_pass, + ) + self._pattern_replacements.append(pr) + + def uuid(self) -> str: + return VllmInductorPass.hash_source( + type(self), + *[type(pr) for pr in self._pattern_replacements], + ) + + @staticmethod + def _trace_fn(*args: Any, **kwargs: Any) -> fx.GraphModule: + gm = pm.fwd_only(*args, **kwargs) + _fx_view_to_reshape(gm) + _remove_noop_permutes(gm) + return gm + + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph) -> None: + self.matched_count = self.pm_pass.apply(graph) + VllmPatternMatcherPass.match_table[self.pass_name] += self.matched_count + + class PrinterInductorPass(VllmInductorPass): def __init__(self, name: str, config: VllmConfig) -> None: super().__init__(config) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 91dcdc2b9..ec6bc2a71 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -703,6 +703,11 @@ class Worker(WorkerBase): def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.model_runner.get_supported_tasks() + def get_compilation_match_table(self) -> dict[str, int]: + from vllm.compilation.passes.vllm_inductor_pass import get_match_table + + return get_match_table() + def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: """Get encoder timing stats from model runner.""" return self.model_runner.get_encoder_timing_stats()