Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
@@ -1190,6 +1190,58 @@ class CompilationConfig:
|
||||
self.max_cudagraph_capture_size = rounded_sizes[-1]
|
||||
self.cudagraph_capture_sizes = rounded_sizes
|
||||
|
||||
def adjust_cudagraph_sizes_for_mamba_cache(
    self, num_mamba_cache_blocks: int
) -> None:
    """Cap cudagraph capture sizes to available Mamba cache blocks.

    For hybrid Mamba/attention models, the Mamba conv_state and
    ssm_state tensors have their first dimension equal to num_blocks
    (from KVCacheConfig). During CUDA graph capture the decode batch
    size equals num_tokens, so capture sizes exceeding num_blocks
    would cause out-of-bounds access in Mamba kernels.

    Capture sizes larger than ``num_mamba_cache_blocks`` are dropped
    (the relative order of the remaining sizes is preserved) and
    ``max_cudagraph_capture_size`` is lowered to the largest surviving
    size. If no size survives, ``cudagraph_capture_sizes`` is set to an
    empty list and ``max_cudagraph_capture_size`` to 0.

    The call is a no-op when no capture sizes are configured, when
    ``num_mamba_cache_blocks`` is not positive, or when it is already
    at least ``max_cudagraph_capture_size``.

    Args:
        num_mamba_cache_blocks: Number of Mamba cache blocks available.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    current_sizes = self.cudagraph_capture_sizes
    if not current_sizes or num_mamba_cache_blocks <= 0:
        return

    # Invariant check (also narrows the Optional type for type checkers).
    assert self.max_cudagraph_capture_size is not None

    if num_mamba_cache_blocks >= self.max_cudagraph_capture_size:
        # Every configured size already fits; nothing to cap.
        return

    capped_sizes = [s for s in current_sizes if s <= num_mamba_cache_blocks]

    if not capped_sizes:
        logger.warning(
            "No valid cudagraph capture sizes remain after capping "
            "to Mamba cache blocks (%d). The smallest capture size "
            "was %d. Disabling cudagraph capture. Consider reducing "
            "max_num_seqs or increasing available GPU memory.",
            num_mamba_cache_blocks,
            # min() rather than [0]: do not rely on the list being sorted.
            min(current_sizes),
        )
        self.cudagraph_capture_sizes = []
        self.max_cudagraph_capture_size = 0
        return

    # max() rather than [-1]: do not rely on the list being sorted.
    new_max_size = max(capped_sizes)

    logger.warning(
        "Capping cudagraph capture sizes from max %d to %d to fit "
        "Mamba cache blocks (%d blocks available). This limits the "
        "maximum batch size that can use CUDA graphs. To increase "
        "this limit, reduce max_num_seqs or increase available GPU "
        "memory.",
        self.max_cudagraph_capture_size,
        new_max_size,
        num_mamba_cache_blocks,
    )

    self.max_cudagraph_capture_size = new_max_size
    self.cudagraph_capture_sizes = capped_sizes
|
||||
|
||||
def get_compile_ranges(self) -> list[Range]:
|
||||
"""Get the compile ranges for the compilation config."""
|
||||
if self.compile_ranges_split_points is None:
|
||||
|
||||
Reference in New Issue
Block a user