[Bugfix] set default cuda_graph_sizes to min(self.max_num_seqs * 2, 512) (#20628)

Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Author: zhrrr
Date: 2025-07-09 11:02:51 +08:00
Committed by: GitHub
Parent: 6db31e7a27
Commit: 34dad19e7b

@@ -2147,11 +2147,12 @@ class SchedulerConfig:
     NOTE: This will be replaced by speculative config in the future; it is
     present to enable correctness tests until then."""
-    cuda_graph_sizes: list[int] = field(default_factory=lambda: [512])
-    """Cuda graph capture sizes, default is 512.
-    1. if one value is provided, then the capture list would follow the
+    cuda_graph_sizes: list[int] = field(default_factory=list)
+    """Cuda graph capture sizes
+    1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
+    2. if one value is provided, then the capture list would follow the
     pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
-    2. more than one value (e.g. 1 2 128) is provided, then the capture list
+    3. more than one value (e.g. 1 2 128) is provided, then the capture list
     will follow the provided list."""
     delay_factor: float = 0.0
@@ -2316,6 +2317,13 @@ class SchedulerConfig:
             self.max_num_partial_prefills, self.max_long_partial_prefills,
             self.long_prefill_token_threshold)
+        # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
+        # This avoids OOM in tight memory scenarios with small max_num_seqs,
+        # and prevents capture of many large graphs (>512) that would greatly
+        # increase startup time with limited performance benefit.
+        if not self.cuda_graph_sizes:
+            self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
+
     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         if (self.max_num_batched_tokens < self.max_model_len
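
As a reading aid, here is a minimal sketch of the resolution logic the updated docstring describes. The helper name resolve_cuda_graph_sizes is hypothetical, introduced only for illustration; of the three branches, only the empty-list default corresponds to code added in this commit, and the single-value expansion follows the pattern quoted in the docstring.

def resolve_cuda_graph_sizes(cuda_graph_sizes: list[int],
                             max_num_seqs: int) -> list[int]:
    # Hypothetical helper, for illustration only.
    if not cuda_graph_sizes:
        # Case 1: nothing provided -> default to [min(max_num_seqs * 2, 512)]
        return [min(max_num_seqs * 2, 512)]
    if len(cuda_graph_sizes) == 1:
        # Case 2: a single value -> expand per the documented pattern
        return [1, 2, 4] + [i for i in range(8, cuda_graph_sizes[0] + 1, 8)]
    # Case 3: an explicit list -> captured as-is
    return list(cuda_graph_sizes)

For example, with max_num_seqs=64 and no sizes given, the default is [128]; with cuda_graph_sizes=[16], the capture list expands to [1, 2, 4, 8, 16].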