Add fused top-K softmax kernel for MoE (#2769)
@@ -25,7 +25,6 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
@@ -155,20 +154,12 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-
-        if self.config.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w2,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=self.config.norm_topk_prob,
                                         inplace=True)
 
         if self.config.n_shared_experts is not None:
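
For reference, below is a minimal sketch of the routing math that the removed lines computed and that the fused top-K softmax kernel now handles inside fused_moe: softmax over the router logits, top-k expert selection, and optional renormalization of the kept weights. It is written here only for illustration; reference_topk_softmax is a hypothetical helper, not part of this commit or the vLLM API.

from typing import Tuple

import torch
import torch.nn.functional as F


def reference_topk_softmax(
        router_logits: torch.Tensor,  # (num_tokens, n_experts)
        top_k: int,
        renormalize: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    # Softmax over the expert dimension in float32, mirroring the removed
    # F.softmax(..., dtype=torch.float) call.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    # Keep only the top-k experts per token.
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # Optionally renormalize the kept weights to sum to 1 per token,
    # as previously gated by config.norm_topk_prob.
    if renormalize:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights, selected_experts

Passing router_logits, self.top_k, and renormalize directly lets the kernel fuse these steps rather than materializing the intermediate softmax and top-k tensors at the Python level.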