[torch.compile] Make inductor partition rules respect splitting_ops #25691 (#25845)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: baonudesifeizhai
Date: 2025-10-10 12:35:28 -04:00
Committed by: GitHub
Parent: e519281920
Commit: cddce79fda
9 changed files with 267 additions and 112 deletions
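
For context, a minimal sketch of the behavior change, mirroring the updated test_splitting_ops_dynamic in the diff below (the config fields and the op name are taken from that test; this is an illustration, not an authoritative usage example):

```python
# Sketch based on the updated test below: with inductor graph partition enabled,
# user-provided splitting_ops are now kept and drive the partition rules instead
# of being reset to []. Assumes PyTorch >= 2.9 for use_inductor_graph_partition.
from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import CompilationLevel

config = VllmConfig(
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=True,
        splitting_ops=["vllm::unified_attention"],
    )
)
# Previously this would have been [] (splitting_ops ignored under inductor
# graph partition); after this change the ops are passed through.
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
```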


@@ -4,10 +4,12 @@ import pytest
 from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.utils import _is_torch_equal_or_newer
+from vllm.config.compilation import CompilationLevel
+from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer

 def test_version():
     # Test the version comparison logic using the private function
     assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_version():
 def test_use_cudagraphs_dynamic():
     vllm_config = VllmConfig()
+    # Default V1 configuration now starts without cudagraphs enabled; the
+    # engine decides when to capture based on runtime settings instead of a
+    # blanket default.
     assert vllm_config.compilation_config.use_cudagraph
@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):
 def test_splitting_ops_dynamic():
     # Default config
     config = VllmConfig()
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
-    assert config.compilation_config.splitting_ops_contain_attention()
+    # Default V1 config leaves cudagraph mode unset; splitting ops are only
+    # populated when the engine decides to use piecewise compilation.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert not config.compilation_config.splitting_ops_contain_attention()

     # When use_inductor_graph_partition=True
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    # inductor graph partition is only available in PyTorch 2.9+.
+    # this is a fast config check so we are not using pytest.skip.
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
+                level=CompilationLevel.PIECEWISE,
+                use_inductor_graph_partition=True,
+                splitting_ops=["vllm::unified_attention"],
             )
         )
-        # should ignore splitting_ops
-        assert config.compilation_config.splitting_ops == []
+        # with inductor partition we use splitting_ops directly for
+        # partition rules
+        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

-    # When attn_fusion pass enabled.
+    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
     config = VllmConfig(
         compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
             pass_config={"enable_attn_fusion": True, "enable_noop": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    assert config.compilation_config.splitting_ops == []
-    # cudagraph mode also fall back to FULL
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
-    # splitting_ops can not contain attention ops when attn_fusion
-    # pass enabled.
-    with pytest.raises(AssertionError):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                pass_config={"enable_attn_fusion": True, "enable_noop": True},
-                custom_ops=["+quant_fp8"],
-                cudagraph_mode=CUDAGraphMode.PIECEWISE,
-                # work around for accessing all attntion ops
-                splitting_ops=CompilationConfig()._attention_ops,
-            )
-        )
+    # With the new simplified logic, attention fusion works with splitting_ops
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # cudagraph mode remains PIECEWISE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

     # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
                 use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
             )
         )
-        assert config.compilation_config.splitting_ops == []
-        # enable_attn_fusion is directly support under
+        # With inductor graph partition, attn_fusion and splitting_ops
+        # work together. Default splitting_ops include attention ops.
+        assert config.compilation_config.splitting_ops_contain_attention()
+        # enable_attn_fusion is directly supported under
         # use_inductor_graph_partition=True, and cudagraph_mode
         # is unchanged.
         assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

+def test_resolve_operator_overload():
+    import torch
+    from vllm.compilation.partition_rules import resolve_defined_ops
+
+    # Test valid operator names
+    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
+    assert len(resolved) == 2
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # Test that invalid operators are skipped (not raising exceptions)
+    resolved = resolve_defined_ops(
+        [
+            "aten::mm.default",
+            "aten::nonexistent_op.default",  # This should be skipped
+            "aten::addmm.default",
+        ]
+    )
+    assert len(resolved) == 2  # Only 2 valid ops
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
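
The new test relies on resolve_defined_ops skipping unknown operator names rather than raising. A rough, hypothetical sketch of that lookup logic (for illustration only; the actual implementation in vllm.compilation.partition_rules may differ):

```python
import torch
from torch._ops import OpOverload


def resolve_defined_ops_sketch(op_names: list[str]) -> list[OpOverload]:
    """Hypothetical stand-in for resolve_defined_ops, shown for illustration."""
    resolved: list[OpOverload] = []
    for name in op_names:
        try:
            # "aten::mm.default" -> namespace "aten", packet "mm", overload "default"
            namespace, rest = name.split("::")
            packet_name, overload_name = rest.split(".")
            packet = getattr(getattr(torch.ops, namespace), packet_name)
            resolved.append(getattr(packet, overload_name))
        except (AttributeError, RuntimeError):
            # Unknown operators are skipped instead of raising, matching the
            # behavior asserted in test_resolve_operator_overload above.
            continue
    return resolved


# e.g. resolve_defined_ops_sketch(["aten::mm.default", "aten::bogus.default"])
# -> [torch.ops.aten.mm.default]
```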