[Misc]add configurable cuda graph size (#17201)

Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Chen Xia
2025-05-01 11:04:50 -07:00
committed by GitHub
parent 4acfa3354a
commit 61c299f81f
2 changed files with 22 additions and 3 deletions

View File

@@ -1865,6 +1865,13 @@ class SchedulerConfig:
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
cuda_graph_sizes: list[int] = field(default_factory=lambda: [512])
"""Cuda graph capture sizes, default is 512.
1. if one value is provided, then the capture list would follow the pattern:
[1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
2. more than one value (e.g. 1 2 128) is provided,
then the capture list will follow the provided list."""
max_num_seqs: int = None # type: ignore
"""Maximum number of sequences to be processed in a single iteration.
@@ -4235,13 +4242,20 @@ class VllmConfig:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
if len(cuda_graph_sizes) == 1:
batch_size_capture_list = [1, 2, 4] + [
i for i in range(8, cuda_graph_sizes[0] + 1, 8)
]
elif len(cuda_graph_sizes) > 1:
batch_size_capture_list = sorted(cuda_graph_sizes)
else:
raise TypeError(
f"Invalid value for {cuda_graph_sizes=}.")
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
batch_size_capture_list = \
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list