[BugFix] Temporary fix for IMA with MTP = 2 and full-cg (#28315)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-11-17 09:41:22 -05:00
committed by GitHub
parent 1b82fb0ad3
commit 64e39d667c
2 changed files with 80 additions and 13 deletions

View File

@@ -18,6 +18,7 @@ from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING:
@@ -773,19 +774,8 @@ class CompilationConfig:
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_cudagraph_capture_size + 1)
]
for end, start in zip(
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
[0] + self.cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when mode is
@@ -922,3 +912,64 @@ class CompilationConfig:
enable_str,
op,
)
def adjust_cudagraph_sizes_for_spec_decode(
    self, uniform_decode_query_len: int, tensor_parallel_size: int
):
    """Align the cudagraph capture sizes to a common multiple.

    The multiple is ``uniform_decode_query_len``; when
    ``tensor_parallel_size`` is greater than 1 it is the larger of the two
    values instead, and that larger value must be divisible by both.

    Every entry of ``self.cudagraph_capture_sizes`` is passed through
    ``round_up`` with that multiple, results above
    ``self.max_cudagraph_capture_size`` are dropped, and the remaining
    values are de-duplicated and sorted. On success the capture sizes and
    the maximum capture size are overwritten and the batch-size padding
    table is rebuilt. Nothing is changed when there are no capture sizes,
    when the multiple is at most 1, or when no size survives the rounding
    (a warning is logged in that last case).

    Args:
        uniform_decode_query_len: Per the error message below, this is
            ``num_speculative_tokens + 1``.
        tensor_parallel_size: Tensor-parallel world size; only taken into
            account when greater than 1.

    Raises:
        ValueError: If ``tensor_parallel_size > 1`` and the larger of the
            two arguments is not a multiple of both of them.
    """
    if tensor_parallel_size > 1:
        multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
        divisible_by_both = (
            multiple_of % uniform_decode_query_len == 0
            and multiple_of % tensor_parallel_size == 0
        )
        if not divisible_by_both:
            raise ValueError(
                f"Can't determine cudagraph shapes that are both a "
                f"multiple of {uniform_decode_query_len} "
                f"(num_speculative_tokens + 1) required by spec-decode "
                f"and {tensor_parallel_size} (tensor_parallel_size) "
                f"required by sequence parallelism please adjust "
                f"num_speculative_tokens or disable sequence parallelism"
            )
    else:
        multiple_of = uniform_decode_query_len

    # Nothing to align: no capture sizes configured, or every size is
    # already a multiple of 1.
    if not self.cudagraph_capture_sizes or multiple_of <= 1:
        return

    largest_allowed = self.max_cudagraph_capture_size
    assert largest_allowed is not None

    # NOTE(review): round_up is presumed to return the smallest multiple of
    # its second argument that is >= the first -- confirm against
    # vllm.utils.math_utils.
    aligned = {
        round_up(size, multiple_of) for size in self.cudagraph_capture_sizes
    }
    rounded_sizes = sorted(
        size for size in aligned if size <= largest_allowed
    )

    if not rounded_sizes:
        logger.warning(
            "No valid cudagraph sizes after rounding to multiple of "
            " num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
            " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
            multiple_of,
        )
        return

    # rounded_sizes is sorted ascending, so its last entry is the new max.
    self.max_cudagraph_capture_size = rounded_sizes[-1]
    self.cudagraph_capture_sizes = rounded_sizes

    # The padding table depends on both attributes set above.
    self.compute_bs_to_padded_graph_size()
def compute_bs_to_padded_graph_size(self):
    """Build the lookup table from batch size to padded cudagraph size.

    Reads ``self.cudagraph_capture_sizes`` and
    ``self.max_cudagraph_capture_size`` and stores the result in
    ``self.bs_to_padded_graph_size``, a list indexed by batch size from 0
    through ``max_cudagraph_capture_size`` inclusive.

    Consecutive capture sizes (with 0 prepended and ``max + 1`` appended)
    form half-open intervals ``[lower, upper)``. The batch size equal to
    ``lower`` maps to itself; every other batch size in the interval maps
    to ``upper``.
    """
    capture_sizes = self.cudagraph_capture_sizes
    table_len = self.max_cudagraph_capture_size + 1

    # Attach the table before filling it so the attribute is always set,
    # exactly as a fill-in-place implementation would leave it.
    padded = [0] * table_len
    self.bs_to_padded_graph_size = padded

    lower_edges = [0] + capture_sizes
    upper_edges = capture_sizes + [table_len]
    for lower, upper in zip(lower_edges, upper_edges):
        for batch_size in range(lower, upper):
            # An exact capture size needs no padding; anything above it is
            # padded to the interval's upper edge.
            padded[batch_size] = lower if batch_size == lower else upper