[torch.compile] Reorganize vllm/compilation and tests/compile (0/N for vLLM IR) (#33731)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: ProExpertProg <luka.govedic@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2026-02-06 07:19:49 -05:00
committed by GitHub
parent f79d9dce16
commit ac32e66cf9
47 changed files with 717 additions and 651 deletions

View File

@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest
import torch
from vllm.compilation.passes.inductor_pass import (
CallableInductorPass,
InductorPass,
pass_context,
)
from vllm.compilation.passes.pass_manager import PostGradPassManager
from vllm.config import ModelConfig, VllmConfig
from vllm.config.utils import Range
# dummy custom pass that doesn't inherit
def simple_callable(graph: torch.fx.Graph):
pass
# Should fail to add directly to the pass manager
def test_bad_callable():
config = VllmConfig()
pass_manager = PostGradPassManager()
pass_manager.configure(config)
with pytest.raises(AssertionError):
pass_manager.add(simple_callable) # type: ignore[arg-type]
# Pass that inherits from InductorPass
class ProperPass(InductorPass):
def __call__(self, graph: torch.fx.graph.Graph) -> None:
pass
@pytest.mark.parametrize(
"callable",
[
ProperPass(),
# Can also wrap callables in CallableInductorPass for compliance
CallableInductorPass(simple_callable),
CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
],
)
def test_pass_manager_uuid(callable):
# Set the pass context as PassManager uuid uses it
with pass_context(Range(start=1, end=8)):
# Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
pass_manager = PostGradPassManager()
pass_manager.configure(config)
# Check that UUID is different if the same pass is added 2x
pass_manager.add(callable)
uuid1 = pass_manager.uuid()
pass_manager.add(callable)
uuid2 = pass_manager.uuid()
assert uuid1 != uuid2
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2 = PostGradPassManager()
pass_manager2.configure(config)
pass_manager2.add(callable)
assert uuid1 == pass_manager2.uuid()
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.compilation_config.pass_config.fuse_norm_quant = (
not config2.compilation_config.pass_config.fuse_norm_quant
)
config2.compilation_config.pass_config.fuse_act_quant = (
not config2.compilation_config.pass_config.fuse_act_quant
)
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid()