[Core] Enable CUDA graphs for DP + All2All kernels (#18724)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Author: Varun Sundar Rabindranath
Date: 2025-05-28 18:55:30 -04:00
Committed by: GitHub
Parent: 6dbe5b5c93
Commit: 7951d78738

4 changed files with 100 additions and 37 deletions

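The change is small: under data parallelism (DP), ranks that disagree on batch size cannot keep All2All kernel shapes in sync or replay CUDA graphs captured at a common size, so each rank now pads its token count up to the maximum across DP ranks. A minimal sketch of that padding rule, using made-up per-rank token counts:

    # Minimal sketch of the DP padding rule, with made-up per-rank counts.
    num_tokens_per_rank = [7, 12, 3, 9]      # tokens scheduled on each DP rank
    max_tokens = max(num_tokens_per_rank)    # the value get_dp_padding() derives

    padded = [n + (max_tokens - n) for n in num_tokens_per_rank]
    assert padded == [12, 12, 12, 12]        # every rank now runs the same size,
                                             # so All2All shapes and the replayed
                                             # CUDA graph size agree across ranks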

@@ -24,7 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import (DPMetadata, get_forward_context,
+                                  set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model
@@ -1104,6 +1105,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             for k, v in self.intermediate_tensors.items()
         })
 
+    def get_dp_padding(self, num_tokens: int):
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
+        if dp_size == 1:
+            # Early exit.
+            return 0
+
+        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
+            num_tokens, dp_size, dp_rank)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
+        return max_tokens_across_dp_cpu - num_tokens
+
     @torch.inference_mode()
     def execute_model(
         self,
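
get_dp_padding returns how many extra tokens this rank must add to reach the DP-wide maximum. A single-process stand-in for the logic (in vLLM, DPMetadata.num_tokens_across_dp communicates over the DP process group; the gathered tensor below is invented):

    import torch

    # Stand-in for get_dp_padding() on one rank. The counts tensor normally
    # comes from DPMetadata.num_tokens_across_dp; here it is hard-coded.
    num_tokens = 5                                       # this rank's tokens
    num_tokens_across_dp = torch.tensor([7, 12, 5, 3])   # hypothetical gather
    max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()

    padding = max_tokens_across_dp_cpu - num_tokens
    print(padding)  # 7 -- this rank pads 5 -> 12 tokens to match the busiest rank
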
@@ -1141,6 +1154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             num_input_tokens = num_scheduled_tokens
 
+        # Padding for DP
+        num_input_tokens += self.get_dp_padding(num_input_tokens)
+
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
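
Because every rank pads to the same total, every rank also resolves to the same captured graph once the padded count is rounded up to a capture size. A sketch of that rounding step, with a hypothetical capture-size list and helper (the runner's actual graph dispatch lives elsewhere):

    import bisect

    # Hypothetical sizes at which CUDA graphs were captured; assumes the
    # padded token count fits within the largest size.
    capture_sizes = [1, 2, 4, 8, 16, 32, 64]

    def round_up_to_capture_size(num_input_tokens: int) -> int:
        # Smallest captured size that fits the DP-padded token count.
        return capture_sizes[bisect.bisect_left(capture_sizes, num_input_tokens)]

    # All DP ranks padded to 12 tokens, so all replay the size-16 graph.
    print(round_up_to_capture_size(12))  # 16
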
@@ -1658,6 +1674,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         skip_attn: bool = True,
     ) -> torch.Tensor:
 
+        # Padding for DP
+        num_tokens += self.get_dp_padding(num_tokens)
+
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
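
The same padding is applied in the dummy-run path so that graphs are captured at exactly the sizes real batches will later request. A toy runner, with an invented DP-wide maximum of 16, showing the invariant:

    # Toy model of the dummy-run path; the class and the DP max are invented.
    class TinyRunner:
        def get_dp_padding(self, num_tokens: int) -> int:
            return max(0, 16 - num_tokens)   # pretend the DP-wide max is 16

        def _dummy_run(self, num_tokens: int) -> int:
            # Padding for DP, mirroring the hunk above.
            num_tokens += self.get_dp_padding(num_tokens)
            # ... build dummy inputs of size num_tokens, run, capture graph ...
            return num_tokens

    assert TinyRunner()._dummy_run(9) == 16  # capture happens at the padded size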