[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f9df8b4ad7
commit
a5354b3ed2
@@ -32,7 +32,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
@@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
@@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Chunk x along the num_tokens axis for sequence parallelism
|
||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
||||
# even though we explicitly pad to avoid this.
|
||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# all_gather needs the sequence length to be divisible by tp_size
|
||||
seq_len = x.size(0)
|
||||
remainder = seq_len % tp_size
|
||||
if remainder != 0:
|
||||
pad_len = tp_size - remainder
|
||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
||||
|
||||
chunk = x.shape[0] // tp_size
|
||||
start = tp_rank * chunk
|
||||
return torch.narrow(x, 0, start, chunk)
|
||||
|
||||
|
||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
seq_len = cdiv(x.size(0), tp_size)
|
||||
shape = list(x.shape)
|
||||
shape[0] = seq_len
|
||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
||||
return out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sequence_parallel_chunk",
|
||||
op_func=sequence_parallel_chunk,
|
||||
fake_impl=sequence_parallel_chunk_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
|
||||
in ("deepep_high_throughput",
|
||||
"deepep_low_latency")
|
||||
and parallel_config.enable_expert_parallel
|
||||
and self.tp_size > 1)
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
@@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
# TODO: We can replace the all_reduce at the end of attn with a
|
||||
# reduce_scatter instead of chunking here.
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
|
||||
hidden_states)
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user