[Model Runner V2] Limit cudagraph size to max decode batch size (#29221)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-11-21 20:21:35 -08:00
committed by GitHub
parent 1489902b53
commit e9056056fb

View File

@@ -27,9 +27,11 @@ class CudaGraphManager:
device: torch.device, device: torch.device,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device self.device = device
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None assert self.compilation_config is not None
@@ -39,9 +41,11 @@ class CudaGraphManager:
else: else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode self.cudagraph_mode = self.compilation_config.cudagraph_mode
if self.compilation_config.cudagraph_capture_sizes is not None: if self.compilation_config.cudagraph_capture_sizes is not None:
self.cudagraph_sizes = sorted( cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
self.compilation_config.cudagraph_capture_sizes # Limit the cudagraph sizes to the max decode batch size.
) self.cudagraph_sizes = [
x for x in cudagraph_sizes if x <= self.max_num_reqs
]
else: else:
self.cudagraph_sizes = [] self.cudagraph_sizes = []
self.padded_sizes = self._init_padded_sizes() self.padded_sizes = self._init_padded_sizes()
@@ -54,9 +58,10 @@ class CudaGraphManager:
if not self.cudagraph_mode.has_full_cudagraphs(): if not self.cudagraph_mode.has_full_cudagraphs():
# Full cuda graphs are not used. # Full cuda graphs are not used.
return {} return {}
if not self.cudagraph_sizes:
return {}
padded_sizes: dict[int, int] = {} padded_sizes: dict[int, int] = {}
assert len(self.cudagraph_sizes) > 0
for i in range(1, self.cudagraph_sizes[-1] + 1): for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes: for x in self.cudagraph_sizes:
if i <= x: if i <= x: