[MISC] cudagraph_capture_sizes related improvements (#26016)
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -8,6 +9,8 @@ from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
|
||||
|
||||
|
||||
@@ -233,3 +236,73 @@ def test_resolve_operator_overload():
|
||||
assert len(resolved) == 2 # Only 2 valid ops
|
||||
assert resolved[0] is torch.ops.aten.mm.default
|
||||
assert resolved[1] is torch.ops.aten.addmm.default
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sequence_parallelism",
        "max_num_batched_tokens",
        "use_cudagraph",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, True, 512),
        ([1, 2, 4], 4, 1, False, 2048, True, 4),
        ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
        # Every row must supply all seven values; `use_cudagraph` is True here.
        ([1, 256], None, 1, False, 2048, True, 256),
        ([], None, 1, False, 2048, False, 0),
        (None, 0, 1, False, 2048, False, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, True, 256),
        ([1, 2, 4, 15], None, 1, False, 2048, True, 15),  # max from list
        ([1, 2, 4, 15], None, 2, True, 2048, True, 4),  # filtered out 15 due to SP
        ([1, 2, 4, 15], None, 1, False, 8, True, 4),  # limited by the max_tokens
        # the list should contain at least 1 element when use cudagraph
        ([], None, 1, False, 2048, True, RuntimeError),
        # the max capturing size should be >= 1 when use cudagraph
        (None, 0, 1, False, 2048, True, RuntimeError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sequence_parallelism,
    max_num_batched_tokens,
    use_cudagraph,
    expected_max_size,
):
    """Check the `max_cudagraph_capture_size` produced by engine config creation.

    Each parametrized row builds a `CompilationConfig`, feeds it through
    `EngineArgs.create_engine_config()`, and compares the resulting
    `max_cudagraph_capture_size` with `expected_max_size`.

    When `expected_max_size` is an exception class instead of an integer,
    building the configuration is expected to raise that exception.
    """
    # The parametrize table passes exception *classes* (e.g. RuntimeError),
    # so test for a class rather than an exception instance.
    expects_error = isinstance(expected_max_size, type) and issubclass(
        expected_max_size, Exception
    )
    ctx = pytest.raises(expected_max_size) if expects_error else nullcontext()

    cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
    with ctx:
        compilation_config = CompilationConfig(
            cudagraph_capture_sizes=cudagraph_capture_sizes,
            max_cudagraph_capture_size=max_cudagraph_capture_size,
            pass_config={
                "enable_sequence_parallelism": enable_sequence_parallelism,
                "enable_fusion": True,
                "enable_noop": True,
            },
            cudagraph_mode=cudagraph_mode,
        )
        engine_args = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        )
        vllm_config = engine_args.create_engine_config()

        # Kept inside the context: on the error path the expected exception
        # leaves `vllm_config` unbound, so this check must be skipped. If no
        # exception is raised, pytest.raises (or this assert) fails the test.
        assert (
            vllm_config.compilation_config.max_cudagraph_capture_size
            == expected_max_size
        )
|
||||
|
||||
Reference in New Issue
Block a user