[New Model]: nomic-embed-text-v2-moe (#17785)

wang.yuqi
2025-05-11 15:59:43 +08:00
committed by GitHub
parent 06c0922a69
commit e4b8713380
9 changed files with 899 additions and 364 deletions
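
This commit adds support for the nomic-embed-text-v2-moe embedding model. A minimal usage sketch, not taken from this diff: it assumes the Hugging Face repo id nomic-ai/nomic-embed-text-v2-moe and vLLM's offline embedding entry point (LLM.embed); the "search_query:" prefix follows Nomic's documented usage conventions.

    # Hedged sketch: embed a query with the newly supported model.
    # Assumes HF repo id "nomic-ai/nomic-embed-text-v2-moe" and vLLM's
    # offline embedding API; not part of this commit.
    from vllm import LLM

    llm = LLM(model="nomic-ai/nomic-embed-text-v2-moe", task="embed")
    (output,) = llm.embed(["search_query: What is vLLM?"])
    print(len(output.outputs.embedding))  # dimensionality of the embedding

The diff below shows one of the nine changed files; the bulk of the new model code (BertWithRope, JinaRobertaModel) lands in a new bert_with_rope module that this file now delegates to.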

vllm/model_executor/models/roberta.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)

+from .bert_with_rope import BertWithRope, JinaRobertaModel
 from .interfaces import SupportsCrossEncoding, SupportsV0Only
@@ -125,39 +126,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     def _build_model(self,
                      vllm_config: VllmConfig,
-                     prefix: str = "") -> BertModel:
+                     prefix: str = "") -> Union[BertModel, BertWithRope]:

         if (vllm_config.model_config.hf_config.position_embedding_type ==
                 "rotary"):
-            config = vllm_config.model_config.hf_config
-            head_dim = config.hidden_size // config.num_attention_heads
-            rotary_kwargs = {
-                "head_size": head_dim,
-                "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-                "max_position": config.max_position_embeddings,
-                "base": config.rotary_emb_base,
-                "rope_scaling": getattr(config, "rope_scaling", None)
-            }
-            return BertModel(vllm_config=vllm_config,
-                             rotary_kwargs=rotary_kwargs,
-                             prefix=prefix)
+            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
         else:
             return BertModel(vllm_config=vllm_config,
                              prefix=prefix,
                              embedding_class=RobertaEmbedding)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if getattr(self.config, "lora_rank", 0) > 0:
-            scaling = self.config.lora_alpha / self.config.lora_rank
-            weights = jina_merge_lora_weights(weights, scaling)
+        weights = self.hf_to_vllm_mapper.apply(weights)

         # Separate weights in "roberta"-prefixed and all else (not in memory).
         # For use with models like FacebookAI/roberta-base.
         bert_weights, task_weights = roberta_task_weights_filter(weights)
-        bert_weights = jina_to_vllm_mapper.apply(bert_weights)
         loaded = self.model.load_weights(bert_weights)
         if not len(loaded):
             # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
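
The rotary branch no longer assembles rotary_kwargs inline; that derivation moves behind JinaRobertaModel in the new bert_with_rope module. For reference, a standalone sketch of the derivation the removed lines performed (attribute names are exactly those in the removed code; BertWithRope's actual internals may organize this differently):

    # Sketch of the rotary-parameter derivation that left roberta.py.
    # Mirrors the removed lines above; illustrative only.
    def rotary_kwargs_from_hf_config(config) -> dict:
        head_dim = config.hidden_size // config.num_attention_heads
        return {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rotary_emb_base,
            "rope_scaling": getattr(config, "rope_scaling", None),
        }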
@@ -178,6 +160,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
         _pooler: An instance of Pooler used for pooling operations.
     """

+    jina_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            'emb_ln': "embeddings.LayerNorm",
+            'layers': "layer",
+            'mixer.Wqkv': "attention.self.qkv_proj",
+            'mixer.out_proj': "attention.output.dense",
+            'norm1': "attention.output.LayerNorm",
+            'mlp.fc1': "intermediate.dense",
+            'mlp.fc2': "output.dense",
+            'norm2': "output.LayerNorm",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -195,7 +189,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         bert_weights, task_weights = roberta_task_weights_filter(weights)
-        bert_weights = jina_to_vllm_mapper.apply(bert_weights)
+        bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
         self.roberta.load_weights(bert_weights)
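
Moving jina_to_vllm_mapper onto the class means load_weights resolves it via self rather than a module-level global. The mapper itself is a plain substring rename over checkpoint parameter names; a self-contained sketch of what apply does here (vLLM's real WeightsMapper, in .utils, also handles prefix maps and dropped names):

    # Self-contained sketch of the substring rename jina_to_vllm_mapper does;
    # vLLM's WeightsMapper is the real implementation.
    from typing import Iterable, Iterator, Tuple

    import torch

    def apply_substr_map(
        substr_map: dict[str, str],
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        for name, tensor in weights:
            for old, new in substr_map.items():
                if old in name:
                    name = name.replace(old, new)
            yield name, tensor

    # e.g. "layers.0.mixer.Wqkv.weight" -> "layer.0.attention.self.qkv_proj.weight"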
@@ -276,57 +270,3 @@ def roberta_task_weights_filter(
     return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
                                        if not n.startswith("roberta."))
-
-
-jina_to_vllm_mapper = WeightsMapper(
-    orig_to_new_substr={
-        'emb_ln': "embeddings.LayerNorm",
-        'layers': "layer",
-        'mixer.Wqkv': "attention.self.qkv_proj",
-        'mixer.out_proj': "attention.output.dense",
-        'norm1': "attention.output.LayerNorm",
-        'mlp.fc1': "intermediate.dense",
-        'mlp.fc2': "output.dense",
-        'norm2': "output.LayerNorm",
-    })
-
-
-@torch.inference_mode()
-def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
-                            scaling: float = 1.0):
-    # use for jina-embeddings-v3
-    # Merge Lora weights into a single weight tensor.
-    # This is a temporary solution until we have a better way to handle
-    weights = {name: weight for name, weight in weights}
-
-    o = ".original"
-    a = ".0.lora_A"
-    b = ".0.lora_B"
-
-    # text-matching
-    i = -1
-
-    for name in list(weights.keys()):
-        if o in name:
-            dtype = weights[name].dtype
-            shape = weights[name].shape
-            weight_name = name[:-len(o)]
-
-            if "embeddings" in weight_name:
-                B = weights[weight_name + a][i].cuda().float()
-                A = weights[weight_name + b][i].cuda().float()
-            else:
-                B = weights[weight_name + b][i].cuda().float()
-                A = weights[weight_name + a][i].cuda().float()
-
-            weight = (weights[weight_name + o].cuda() +
-                      torch.matmul(B, A).view(shape) * scaling)
-            weight = weight.cpu().to(dtype)
-
-            weights[weight_name.replace(".parametrizations", "")] = weight
-
-            del weights[weight_name + o], weights[weight_name +
-                                                  a], weights[weight_name + b]
-
-    return [(name, weight) for name, weight in weights.items()]
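
The helper removed above folded one LoRA adapter (index -1, the text-matching task of jina-embeddings-v3) into each base weight as W' = W + scaling * (B @ A), upcasting to float32 on GPU and casting back afterwards; note it swapped the A/B roles for embedding weights. A minimal sketch of that core merge step, with hypothetical tensor arguments standing in for the checkpoint-dict bookkeeping:

    # Hedged sketch of the core LoRA fold-up the removed helper performed:
    # W' = W + scaling * (B @ A), computed in float32, cast back to W's dtype.
    import torch

    def merge_lora_delta(original: torch.Tensor, lora_a: torch.Tensor,
                         lora_b: torch.Tensor, scaling: float) -> torch.Tensor:
        dtype = original.dtype
        delta = torch.matmul(lora_b.float(), lora_a.float()).view(original.shape)
        return (original.float() + delta * scaling).to(dtype)

With this commit the merge no longer happens in roberta.py; the rotary/Jina path is handled by JinaRobertaModel in the new bert_with_rope module instead.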