Modularize fused experts and integrate PPLX kernels (#15956)
This commit is contained in:
@@ -31,9 +31,7 @@ from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -143,7 +141,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(
|
||||
),
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
|
||||
@@ -154,6 +153,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -171,9 +171,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
final_hidden_states = final_hidden_states + shared_output \
|
||||
* (1. / self.routed_scaling_factor)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
final_hidden_states = (
|
||||
self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states))
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user