[MISC] cudagraph_capture_sizes related improvements (#26016)

Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
fhl2000
2025-10-24 20:11:05 +08:00
committed by GitHub
parent 435be10db9
commit 284cc92275
14 changed files with 303 additions and 110 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
import pytest
@@ -8,6 +9,8 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
@@ -233,3 +236,73 @@ def test_resolve_operator_overload():
assert len(resolved) == 2 # Only 2 valid ops
assert resolved[0] is torch.ops.aten.mm.default
assert resolved[1] is torch.ops.aten.addmm.default
@pytest.mark.skipif(
not current_platform.support_static_graph_mode(),
reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
(
"cudagraph_capture_sizes",
"max_cudagraph_capture_size",
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"use_cudagraph",
"expected_max_size",
),
[
(None, None, 1, False, 2048, True, 512),
([1, 2, 4], 4, 1, False, 2048, True, 4),
([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
([1, 256], None, 1, False, 2048, 256),
([], None, 1, False, 2048, False, 0),
(None, 0, 1, False, 2048, False, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, True, 256),
([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list
([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP
([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens
# the list should contain at least 1 element when use cudagraph
([], None, 1, False, 2048, True, RuntimeError),
# the max capturing size should be >= 1 when use cudagraph
(None, 0, 1, False, 2048, True, RuntimeError),
],
)
def test_cudagraph_sizes_post_init(
cudagraph_capture_sizes,
max_cudagraph_capture_size,
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
use_cudagraph,
expected_max_size,
):
ctx = nullcontext()
if isinstance(expected_max_size, Exception):
ctx = pytest.raises(expected_max_size)
cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
with ctx:
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
pass_config={
"enable_sequence_parallelism": enable_sequence_parallelism,
"enable_fusion": True,
"enable_noop": True,
},
cudagraph_mode=cudagraph_mode,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()
assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)