[Model Runner V2][Perf] align dummy_run tokens to uniform decode for dp cudagraph (#35376)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
@@ -39,6 +39,7 @@ from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
@@ -327,12 +328,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
|
||||
self,
|
||||
num_tokens: int,
|
||||
*args,
|
||||
skip_attn: bool = True,
|
||||
uniform_decode: bool = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
# Create a dummy scheduler output.
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
if uniform_decode:
|
||||
# Align tokens to uniform_decode_query_len for cudagraph
|
||||
# compatibility across DP ranks.
|
||||
query_len = self.cudagraph_manager.uniform_decode_query_len
|
||||
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
|
||||
num_tokens = num_reqs * query_len
|
||||
num_tokens_per_request = [query_len] * num_reqs
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
assert sum(num_tokens_per_request) == num_tokens
|
||||
num_scheduled_tokens = {
|
||||
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
|
||||
|
||||
Reference in New Issue
Block a user