[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
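The title describes moving construction of the FusedMoEQuantConfig out of the MoE layer and into the individual FusedMoEMethodBase subclasses, so that each quantization method builds the config describing its own quantized code path. Below is a minimal sketch of that delegation pattern only; the class and method names (QuantConfigSketch, MoEMethodSketch, get_quant_config) are hypothetical stand-ins, not vLLM's actual classes or signatures.

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class QuantConfigSketch:
    # Hypothetical stand-in for FusedMoEQuantConfig: just the fields a
    # fused-MoE kernel might consult to pick a quantized code path.
    quant_dtype: str | None = None
    per_channel: bool = False


class MoEMethodSketch(ABC):
    # Hypothetical stand-in for FusedMoEMethodBase: each quantization
    # method subclass constructs its own quant config instead of the
    # MoE layer assembling one on the method's behalf.
    @abstractmethod
    def get_quant_config(self) -> QuantConfigSketch | None:
        ...


class UnquantizedMethodSketch(MoEMethodSketch):
    def get_quant_config(self) -> QuantConfigSketch | None:
        return None  # no quantization, so nothing to construct


class Fp8MethodSketch(MoEMethodSketch):
    def get_quant_config(self) -> QuantConfigSketch | None:
        return QuantConfigSketch(quant_dtype="fp8", per_channel=True)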
@@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                 QKVParallelLinear,
@@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.ws,
-                                        self.w2s,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True)
+        topk_weights, topk_ids, _ = fused_topk(hidden_states,
+                                               router_logits,
+                                               self.top_k,
+                                               renormalize=True)
+
+        final_hidden_states = fused_experts(hidden_states,
+                                            self.ws,
+                                            self.w2s,
+                                            topk_weights,
+                                            topk_ids,
+                                            inplace=True)
+
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
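In the second hunk, the single fused_moe call is split into an explicit routing step (fused_topk) followed by the expert computation (fused_experts). Below is a rough pure-PyTorch reference of what those two steps compute, assuming the gate/up-packed weight layout (w1: [E, 2*I, H], w2: [E, H, I]) and SiLU activation used by the fused path; the function names are illustrative helpers, not vLLM APIs.

import torch


def naive_topk_route(router_logits: torch.Tensor, top_k: int,
                     renormalize: bool = True):
    # Routing step: softmax over experts, keep the top-k probabilities
    # per token, optionally renormalize them to sum to 1.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def naive_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
                  w2: torch.Tensor, topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor) -> torch.Tensor:
    # Expert step: each expert is a SwiGLU-style MLP whose gate/up
    # projections are packed into w1 and whose down projection is w2.
    out = torch.zeros_like(hidden_states)
    for slot in range(topk_ids.shape[1]):
        for expert in range(w1.shape[0]):
            mask = topk_ids[:, slot] == expert
            if not mask.any():
                continue
            x = hidden_states[mask]
            gate, up = (x @ w1[expert].t()).chunk(2, dim=-1)
            act = torch.nn.functional.silu(gate) * up
            expert_out = act @ w2[expert].t()
            scale = topk_weights[mask, slot].unsqueeze(-1).to(x.dtype)
            out[mask] += scale * expert_out
    return out


# Tiny smoke test on random data.
E, H, I, T, K = 4, 8, 16, 5, 2
x = torch.randn(T, H)
w1 = torch.randn(E, 2 * I, H) * 0.1
w2 = torch.randn(E, H, I) * 0.1
logits = torch.randn(T, E)
weights, ids = naive_topk_route(logits, K)
y = naive_experts(x, w1, w2, weights, ids)
print(y.shape)  # torch.Size([5, 8])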