[Core] Enable CUDA graphs for DP + All2All kernels (#18724)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent 6dbe5b5c93
commit 7951d78738
@@ -828,6 +828,21 @@ class FusedMoE(torch.nn.Module):
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+        # Chunked all2all staging tensor
+        self.batched_hidden_states: Optional[torch.Tensor] = None
+        self.batched_router_logits: Optional[torch.Tensor] = None
+        if self.moe_parallel_config.use_pplx_kernels:
+            act_dtype = vllm_config.model_config.dtype
+            self.batched_hidden_states = torch.zeros(
+                (MOE_DP_CHUNK_SIZE, self.hidden_size),
+                dtype=act_dtype,
+                device=torch.cuda.current_device())
+
+            self.batched_router_logits = torch.zeros(
+                (MOE_DP_CHUNK_SIZE, self.global_num_experts),
+                dtype=act_dtype,
+                device=torch.cuda.current_device())
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
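Why the staging tensors are allocated once, at their maximum size: CUDA graph replay requires every kernel in the captured graph to read and write the same device addresses it was captured with, so the dispatch/combine path cannot allocate fresh activation tensors per step. Below is a minimal, self-contained sketch of that general capture/replay pattern; the sizes and the toy relu computation are illustrative only and are not vLLM's actual capture path.

import torch

# Illustrative sizes; the real buffers are sized from MOE_DP_CHUNK_SIZE and
# the model's hidden size.
CHUNK, HIDDEN = 256, 1024

static_in = torch.zeros(CHUNK, HIDDEN, device="cuda")
static_out = torch.zeros(CHUNK, HIDDEN, device="cuda")

# Warm up on a side stream, then capture. The graph records the kernels and
# the fixed addresses of static_in / static_out.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(torch.relu(static_in))
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(torch.relu(static_in))

# Replay: new inputs are copied into the captured buffer, never reallocated.
static_in.copy_(torch.randn(CHUNK, HIDDEN, device="cuda"))
g.replay()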
@@ -1217,18 +1232,39 @@ class FusedMoE(torch.nn.Module):
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
+        assert self.batched_hidden_states is not None
+        assert self.batched_router_logits is not None
+        assert self.batched_hidden_states.dtype == full_hidden_states.dtype
+        assert self.batched_router_logits.dtype == full_router_logits.dtype
+        # Check size compatibility.
+        assert (
+            self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
+        assert (
+            self.batched_router_logits.size(-1) == full_router_logits.size(-1))
+
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
         def process_chunk(chunk_start, chunk_end, skip_result_store=False):
+            chunk_size = chunk_end - chunk_start
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
 
+            assert (self.batched_hidden_states.size(0)  # type: ignore
+                    >= chunk_size)
+            assert (self.batched_router_logits.size(0)  # type: ignore
+                    >= chunk_size)
+            staged_hidden_states = self.batched_hidden_states[:
+                                                              chunk_size, :]  # type: ignore
+            staged_router_logits = self.batched_router_logits[:
+                                                              chunk_size, :]  # type: ignore
+            staged_hidden_states.copy_(hidden_states, non_blocking=True)
+            staged_router_logits.copy_(router_logits, non_blocking=True)
+
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
                 layer=self,
-                x=hidden_states,
-                router_logits=router_logits,
+                x=staged_hidden_states,
+                router_logits=staged_router_logits,
                 top_k=self.top_k,
                 renormalize=self.renormalize,
                 use_grouped_topk=self.use_grouped_topk,
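A note on the staged views: batched_hidden_states[:chunk_size, :] shares storage with the pre-allocated buffer, so the address the MoE kernels see is the same on every call even though chunk_size varies, and the non_blocking copies simply move the current chunk into place on the current stream. A small sketch of that property, with made-up sizes:

import torch

# Hypothetical sizes for illustration only.
MOE_DP_CHUNK_SIZE, HIDDEN = 256, 1024
staging = torch.zeros(MOE_DP_CHUNK_SIZE, HIDDEN, device="cuda")

for chunk in (torch.randn(64, HIDDEN, device="cuda"),
              torch.randn(200, HIDDEN, device="cuda")):
    staged = staging[:chunk.size(0)]        # view into the fixed buffer
    staged.copy_(chunk, non_blocking=True)  # async copy on the current stream
    # The view's base address never changes, which is what CUDA graph replay
    # expects of kernel inputs.
    assert staged.data_ptr() == staging.data_ptr()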
@@ -1244,7 +1280,7 @@ class FusedMoE(torch.nn.Module):
 
             if not skip_result_store:
                 full_final_hidden_states[chunk_start:chunk_end, :].copy_(
-                    final_hidden_states)
+                    final_hidden_states, non_blocking=True)
 
         ctx = get_forward_context()
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
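max_tokens_across_dp is what lets every data-parallel rank walk its batch in the same number of fixed-size chunks, so the all2all collectives stay in lockstep; ranks that run out of local tokens still execute the chunk but pass skip_result_store=True. The driver below is a hedged sketch of how process_chunk might be invoked under that assumption; the clamping details are illustrative, not vLLM's exact logic.

def run_chunked(num_local_tokens: int, max_tokens_across_dp: int,
                chunk_size: int, process_chunk) -> None:
    # Hypothetical driver loop: every rank iterates the same number of times.
    for start in range(0, max_tokens_across_dp, chunk_size):
        end = min(start + chunk_size, max_tokens_across_dp)
        # Ranks with fewer local tokens replay a valid dummy slice purely to
        # keep the collectives in step, and skip storing the result.
        local_start = min(start, max(num_local_tokens - 1, 0))
        local_end = max(min(end, num_local_tokens), local_start + 1)
        process_chunk(local_start, local_end,
                      skip_result_store=start >= num_local_tokens)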