[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building (#28579)

This commit is contained in:
Lucas Wilkinson
2025-11-26 14:07:13 -05:00
committed by GitHub
parent 430dd4d9eb
commit 56539cddac
10 changed files with 401 additions and 283 deletions

View File

@@ -8,12 +8,13 @@ from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
@@ -487,6 +488,7 @@ class Worker(WorkerBase):
hidden_states, last_hidden_states = self.model_runner._dummy_run(
num_tokens=max_num_reqs,
skip_eplb=True,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
@@ -534,12 +536,39 @@ class Worker(WorkerBase):
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, num_input_tokens
all_gather_tensors = {}
compilation_config = self.vllm_config.compilation_config
parallel_config = self.vllm_config.parallel_config
if (
parallel_config.pipeline_parallel_size > 1
and compilation_config.pass_config.enable_sequence_parallelism
and forward_pass
):
# currently only supported by V1 GPUModelRunner
assert isinstance(self.model_runner, GPUModelRunner)
num_scheduled_tokens_np = np.array(
list(scheduler_output.num_scheduled_tokens.values()),
dtype=np.int32,
)
}
# TODO(lucas): This is pretty gross; ideally we should only ever call
# `_determine_batch_execution_and_padding` once (will get called again
# in `execute_model`) but this requires a larger refactor of PP.
_, batch_desc, _, _ = (
self.model_runner._determine_batch_execution_and_padding(
num_tokens=num_scheduled_tokens,
num_reqs=len(num_scheduled_tokens_np),
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
use_cascade_attn=False, # TODO(lucas): Handle cascade attention
)
)
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, batch_desc.num_tokens
)
}
if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),