[torch.compile] CUDAGraph Inductor partition integration (#24281)
Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@@ -10,9 +11,13 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from tests.v1.attention.utils import _Backend
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
PassConfig)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
@@ -105,6 +110,18 @@ def test_full_graph(
|
||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
debug_dump_path=tempfile.gettempdir()),
|
||||
("facebook/opt-125m", {})),
|
||||
] + [
|
||||
# graph inductor partition
|
||||
(
|
||||
CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
# inductor graph partition uses
|
||||
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
|
||||
use_inductor_graph_partition=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
compile_sizes=[1, 2]),
|
||||
model) for model in models_list(all=False)
|
||||
if is_torch_equal_or_newer("2.9.0.dev")
|
||||
])
|
||||
# only test some of the models
|
||||
@create_new_process_for_each_test()
|
||||
@@ -112,11 +129,51 @@ def test_custom_compile_config(
|
||||
compilation_config: CompilationConfig,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
):
|
||||
if (compilation_config.use_inductor_graph_partition
|
||||
and not is_torch_equal_or_newer("2.9.0.dev")):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
|
||||
model, model_kwargs = model_info
|
||||
print(f"MODEL={model}")
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
|
||||
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
|
||||
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_inductor_graph_partition=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
custom_ops=["+quant_fp8"],
|
||||
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
|
||||
)
|
||||
model_kwargs = {
|
||||
"kv_cache_dtype": "fp8",
|
||||
"max_model_len": 1024,
|
||||
}
|
||||
with caplog_vllm.at_level(
|
||||
logging.DEBUG), global_force_attn_backend_context_manager(
|
||||
_Backend.FLASHINFER):
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
try:
|
||||
assert ("Fused quantization onto 48 attention nodes"
|
||||
in caplog_vllm.text), caplog_vllm.text
|
||||
except AssertionError:
|
||||
# Note: this message is only triggered when the compilation goes
|
||||
# through the custom pass. Due to multiple layers of cache on
|
||||
# PyTorch side, the compilation of a graph may be cached such
|
||||
# that custom pass directly goes through cache. In this case,
|
||||
# we go through this branch and assert that the pass is not
|
||||
# triggered.
|
||||
assert "Fused quantization" not in caplog_vllm.text
|
||||
|
||||
|
||||
def run_model(compile_config: Union[int, CompilationConfig], model: str,
|
||||
model_kwargs: dict[str, Any]):
|
||||
prompts = [
|
||||
|
||||
Reference in New Issue
Block a user