[Model Runner V2] Limit cudagraph size to max decode batch size (#29221)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -27,9 +27,11 @@ class CudaGraphManager:
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.max_model_len = vllm_config.model_config.max_model_len
|
self.max_model_len = vllm_config.model_config.max_model_len
|
||||||
|
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
assert self.compilation_config is not None
|
assert self.compilation_config is not None
|
||||||
@@ -39,9 +41,11 @@ class CudaGraphManager:
|
|||||||
else:
|
else:
|
||||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||||
if self.compilation_config.cudagraph_capture_sizes is not None:
|
if self.compilation_config.cudagraph_capture_sizes is not None:
|
||||||
self.cudagraph_sizes = sorted(
|
cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
|
||||||
self.compilation_config.cudagraph_capture_sizes
|
# Limit the cudagraph sizes to the max decode batch size.
|
||||||
)
|
self.cudagraph_sizes = [
|
||||||
|
x for x in cudagraph_sizes if x <= self.max_num_reqs
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
self.cudagraph_sizes = []
|
self.cudagraph_sizes = []
|
||||||
self.padded_sizes = self._init_padded_sizes()
|
self.padded_sizes = self._init_padded_sizes()
|
||||||
@@ -54,9 +58,10 @@ class CudaGraphManager:
|
|||||||
if not self.cudagraph_mode.has_full_cudagraphs():
|
if not self.cudagraph_mode.has_full_cudagraphs():
|
||||||
# Full cuda graphs are not used.
|
# Full cuda graphs are not used.
|
||||||
return {}
|
return {}
|
||||||
|
if not self.cudagraph_sizes:
|
||||||
|
return {}
|
||||||
|
|
||||||
padded_sizes: dict[int, int] = {}
|
padded_sizes: dict[int, int] = {}
|
||||||
assert len(self.cudagraph_sizes) > 0
|
|
||||||
for i in range(1, self.cudagraph_sizes[-1] + 1):
|
for i in range(1, self.cudagraph_sizes[-1] + 1):
|
||||||
for x in self.cudagraph_sizes:
|
for x in self.cudagraph_sizes:
|
||||||
if i <= x:
|
if i <= x:
|
||||||
|
|||||||
Reference in New Issue
Block a user