[Core] Enable CUDA graphs for DP + All2All kernels (#18724)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-05-28 18:55:30 -04:00
committed by GitHub
parent 6dbe5b5c93
commit 7951d78738
4 changed files with 100 additions and 37 deletions

View File

@@ -828,6 +828,21 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params)
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype,
device=torch.cuda.current_device())
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@@ -1217,18 +1232,39 @@ class FusedMoE(torch.nn.Module):
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
assert self.batched_router_logits.dtype == full_router_logits.dtype
# Check size compatibility.
assert (
self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
full_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
assert (self.batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore
staged_router_logits = self.batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
x=staged_hidden_states,
router_logits=staged_router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
@@ -1244,7 +1280,7 @@ class FusedMoE(torch.nn.Module):
if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states)
final_hidden_states, non_blocking=True)
ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu