[Model]: Fused MoE for nomic-embed-text-v2-moe (#18321)
Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -10,9 +10,12 @@ from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
||||
get_act_fn)
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@@ -26,6 +29,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsV0Only
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
@@ -201,114 +206,101 @@ class BertWithRopeMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NomicRouter(nn.Module):
|
||||
class NomicMoE(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.moe_top_k = moe_top_k
|
||||
self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
|
||||
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
|
||||
dim=-1, dtype=torch.float32)
|
||||
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
|
||||
weights = weights.to(x.dtype)
|
||||
top_weights = top_weights.to(x.dtype)
|
||||
return weights, top_weights, top_experts # type: ignore
|
||||
|
||||
|
||||
class NomicExpertMLP(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, ffn_hidden_size: int,
|
||||
moe_num_experts: int, ffn_act_fn: str):
|
||||
super().__init__()
|
||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.hidden_size = hidden_size
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.moe_num_experts = moe_num_experts
|
||||
self.total_intermediate_size = intermediate_size
|
||||
self.intermediate_size = divide(intermediate_size, self.tp_size)
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
self.router = ReplicatedLinear(self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False)
|
||||
self.w1 = nn.Parameter(
|
||||
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
||||
torch.empty(self.num_total_experts,
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype))
|
||||
self.w2 = nn.Parameter(
|
||||
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
||||
self.activation_fn = get_act_fn(ffn_act_fn)
|
||||
torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype))
|
||||
self.bias = nn.Parameter(torch.zeros(self.hidden_size))
|
||||
set_weight_attrs(self.w1, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
|
||||
expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
|
||||
self.hidden_size)[expert_idx]
|
||||
expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
|
||||
self.hidden_size)[expert_idx]
|
||||
def weight_loader(
|
||||
self,
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
):
|
||||
# NOTE: Nomic-MoE has fused experts weights with shape
|
||||
# (num_experts * intermediate_size, hidden_size)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
if weight_name.endswith("w1"):
|
||||
loaded_weight = loaded_weight.reshape(
|
||||
self.num_total_experts,
|
||||
self.total_intermediate_size,
|
||||
self.hidden_size,
|
||||
)[:, shard]
|
||||
if weight_name.endswith("w2"):
|
||||
loaded_weight = loaded_weight.reshape(
|
||||
self.num_total_experts,
|
||||
self.total_intermediate_size,
|
||||
self.hidden_size,
|
||||
)[:, shard].transpose(1, 2)
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
x1 = x.matmul(expert_w1.t())
|
||||
act_out = self.activation_fn(x1)
|
||||
x2 = act_out.matmul(expert_w2)
|
||||
return x2
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=False,
|
||||
inplace=False,
|
||||
activation=self.hidden_act,
|
||||
is_act_and_mul=False)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
class NomicExperts(nn.Module):
|
||||
|
||||
def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
|
||||
moe_num_experts: int):
|
||||
super().__init__()
|
||||
self.moe_num_experts = moe_num_experts
|
||||
|
||||
self.mlp = NomicExpertMLP(hidden_size=config.n_embd,
|
||||
ffn_hidden_size=config.n_inner,
|
||||
moe_num_experts=moe_num_experts,
|
||||
ffn_act_fn=config.hidden_act)
|
||||
self.bias = nn.Parameter(torch.zeros(config.n_embd))
|
||||
|
||||
def forward(self, x: torch.Tensor, weights: torch.Tensor,
|
||||
top_weights: torch.Tensor,
|
||||
top_experts: torch.LongTensor) -> torch.Tensor:
|
||||
q_len, hidden_size = x.shape
|
||||
x = x.view(-1, hidden_size)
|
||||
out = torch.zeros_like(x)
|
||||
|
||||
expert_mask = nn.functional.one_hot(
|
||||
top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
|
||||
for expert_idx in range(0, self.moe_num_experts):
|
||||
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
|
||||
if token_idx.shape[0] == 0:
|
||||
continue
|
||||
|
||||
token_list = token_idx.tolist()
|
||||
topk_list = topk_idx.tolist()
|
||||
|
||||
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
|
||||
expert_out = self.mlp(
|
||||
expert_tokens, expert_idx) * top_weights[token_list, topk_list,
|
||||
None]
|
||||
|
||||
out.index_add_(0, token_idx, expert_out)
|
||||
|
||||
out = out.reshape(q_len, hidden_size)
|
||||
return out + self.bias
|
||||
|
||||
|
||||
class NomicMoELayer(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
|
||||
self.router = NomicRouter(
|
||||
config.n_embd,
|
||||
moe_num_experts=config.num_experts,
|
||||
moe_top_k=config.moe_top_k,
|
||||
)
|
||||
|
||||
self.experts = NomicExperts(
|
||||
config,
|
||||
hidden_size=config.n_embd,
|
||||
ffn_hidden_size=config.n_inner,
|
||||
moe_num_experts=config.num_experts,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
weights, top_weights, top_experts = self.router(x)
|
||||
out = self.experts(x, weights, top_weights, top_experts)
|
||||
return out
|
||||
return final_hidden_states.view(num_tokens, hidden_size) + self.bias
|
||||
|
||||
|
||||
class BertWithRopeBlock(nn.Module):
|
||||
@@ -332,7 +324,11 @@ class BertWithRopeBlock(nn.Module):
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
if moe:
|
||||
self.mlp = NomicMoELayer(config=config, )
|
||||
self.mlp = NomicMoE(num_experts=config.num_experts,
|
||||
top_k=config.moe_top_k,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act)
|
||||
else:
|
||||
if config.hidden_act in ["silu", "geglu"]:
|
||||
self.mlp = BertWithRopeGatedMLP(
|
||||
@@ -463,7 +459,11 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name.endswith((".w1", ".w2")):
|
||||
# Nomic-MoE has fused experts weights
|
||||
weight_loader(param, loaded_weight, name)
|
||||
else:
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@@ -481,6 +481,10 @@ class NomicBertModel(BertWithRope):
|
||||
"mlp.fc12": "mlp.gate_proj",
|
||||
"mlp.fc2": "mlp.down_proj",
|
||||
"norm2": "mlp_ln",
|
||||
# MoE mapping
|
||||
"experts.mlp.": "",
|
||||
"experts.": "",
|
||||
"router.layer": "router",
|
||||
})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user