[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -147,63 +148,85 @@ class DeepseekV2MoE(nn.Module):
         self.physical_expert_end = (self.physical_expert_start +
                                     self.n_local_physical_experts)
 
-        self.experts = FusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
-            # we do scaling outside, set factor to 1.0 to avoid double mul
-            routed_scaling_factor=1.0,
-            e_score_correction_bias=self.gate.e_score_correction_bias,
-            enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
-
-        if config.n_shared_experts is not None:
+        if config.n_shared_experts is None:
+            self.experts = FusedMoE(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func=config.scoring_func,
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts)
+
+            self.shared_experts = None
+        else:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
+
             self.shared_experts = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
+                reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
 
+            self.experts = SharedFusedMoE(
+                shared_experts=self.shared_experts,
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func=config.scoring_func,
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        if hidden_states.dtype != torch.float16:
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits) * self.routed_scaling_factor
+        fused_moe_out = self.experts(hidden_states=hidden_states,
+                                     router_logits=router_logits)
+
+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = fused_moe_out
         else:
-            # Fix FP16 overflow
-            # See DeepseekV2DecoderLayer for more details.
-            final_hidden_states = self.experts(hidden_states=hidden_states,
-                                               router_logits=router_logits)
-        if shared_output is not None:
-            if hidden_states.dtype != torch.float16:
-                final_hidden_states = final_hidden_states + shared_output
-            else:
-                # Fix FP16 overflow
-                # See DeepseekV2DecoderLayer for more details.
-                final_hidden_states = final_hidden_states + shared_output \
-                    * (1. / self.routed_scaling_factor)
+            shared_output = None
+            final_hidden_states = fused_moe_out
+
+        # Fix FP16 overflow
+        # See DeepseekV2DecoderLayer for more details.
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states *= self.routed_scaling_factor
+        elif self.shared_experts is not None:
+            assert shared_output is not None
+            shared_output *= (1. / self.routed_scaling_factor)
+
+        if self.shared_experts is not None:
+            assert shared_output is not None
+            final_hidden_states += shared_output
 
         if self.tp_size > 1:
             final_hidden_states = (
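
The restructuring above routes the shared-expert MLP through SharedFusedMoE, so a single call computes both paths and returns them as a (shared_output, routed_output) pair that forward() unpacks; this is what lets the shared experts execute while the routed experts' expert-parallel send/recv is in flight. Below is a minimal conceptual sketch of that overlap pattern, not vLLM's implementation: OverlappedSharedMoE, dispatch_async() and combine() are hypothetical stand-ins for the real fused MoE kernels and their all-to-all communication.

# Minimal conceptual sketch of overlapping shared experts with send/recv.
# NOTE: dispatch_async() and combine() are hypothetical stand-ins for the real
# expert-parallel all-to-all kernels; this is not vLLM's SharedFusedMoE.
import torch
import torch.nn as nn


class OverlappedSharedMoE(nn.Module):

    def __init__(self, routed_experts: nn.Module, shared_experts: nn.Module):
        super().__init__()
        self.routed_experts = routed_experts  # expert-parallel routed experts
        self.shared_experts = shared_experts  # dense MLP replicated on every rank

    def forward(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # 1) Start the asynchronous token dispatch (send/recv) to remote experts.
        handle = self.routed_experts.dispatch_async(hidden_states, router_logits)

        # 2) While that communication is in flight, run the shared experts
        #    locally; this is the work hidden behind the send/recv.
        shared_output = self.shared_experts(hidden_states)

        # 3) Wait for the routed-expert results and combine them.
        routed_output = self.routed_experts.combine(handle)

        # Return both, mirroring how callers unpack the fused layer's output.
        return shared_output, routed_output

Splitting dispatch, local shared-expert compute, and combine this way also explains why the shared-expert MLP in the diff is now built with reduce_results=False: both outputs are summed first in forward(), and any tensor-parallel reduction happens once on the combined result.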