Add fused top-K softmax kernel for MoE (#2769)

Woosuk Kwon
2024-02-05 17:38:02 -08:00
committed by GitHub
parent 2ccee3def6
commit f0d4e14557
9 changed files with 591 additions and 50 deletions


@@ -25,7 +25,6 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 from vllm.model_executor.input_metadata import InputMetadata
@@ -155,20 +154,12 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        if self.config.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w2,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=self.config.norm_topk_prob,
                                         inplace=True)
         if self.config.n_shared_experts is not None:
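
With this change the caller hands fused_moe the raw router_logits plus top_k and a renormalize flag, and the softmax, top-K selection, and optional renormalization that were previously done as separate PyTorch ops now run inside the fused kernel. As a rough reference (a minimal sketch, not the CUDA kernel; the helper name and example shapes below are assumptions for illustration), the unfused equivalent of that routing step looks like this:

```python
import torch


def topk_softmax_reference(router_logits: torch.Tensor,
                           top_k: int,
                           renormalize: bool):
    # Illustrative unfused routing math, not vLLM's fused kernel.
    # router_logits: (num_tokens, num_experts)
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
    # Keep only the top-k experts per token.
    topk_weights, topk_ids = torch.topk(routing_weights, top_k, dim=-1)
    if renormalize:
        # Rescale so the selected experts' weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Example: 4 tokens routed over 8 experts with top-2 routing.
logits = torch.randn(4, 8)
weights, ids = topk_softmax_reference(logits, top_k=2, renormalize=True)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```

Doing this inside one kernel rather than as several PyTorch ops saves the extra kernel launches and intermediate tensors per forward pass, which is the point of the fused top-K softmax kernel this commit adds.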