[New Model]: jinaai/jina-embeddings-v3 (#16120)
vllm/model_executor/models/roberta.py

@@ -22,30 +22,6 @@ from vllm.transformers_utils.config import (
 from .interfaces import SupportsCrossEncoding, SupportsV0Only
 
 
-def roberta_task_weights_filter(
-    all_weights: Iterable[Tuple[str, torch.Tensor]]
-) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
-                                                              torch.Tensor]]]:
-    """
-    Separate task-specific weights that are applied on top
-    of the encoder-decoder bert base.
-    To do so, return two generators over the original iterator.
-    Also, remove the "roberta." prefix to make it loadable
-    from vanilla BertModel.
-    """
-    # Copy of a lazy iterator without in-memory overhead so both
-    # iterators can be iterated upon independently.
-    all_weights1, all_weights2 = itertools.tee(all_weights)
-
-    def encoder_decoder_weights():
-        for name, weight in all_weights1:
-            if name.startswith("roberta."):
-                yield (name[len("roberta."):], weight)
-
-    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
-                                       if not n.startswith("roberta."))
-
-
 class RobertaEmbedding(nn.Module):
 
     def __init__(self, config: RobertaConfig):
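For illustration, a minimal standalone sketch of the lazy two-way split this helper performs (toy names and integer stand-ins for tensors; not part of the diff):

import itertools

# A toy checkpoint stream; real entries are (str, torch.Tensor) pairs.
stream = iter([("roberta.encoder.w", 1), ("classifier.dense.w", 2)])

# tee() duplicates the iterator so both consumers can advance
# independently without materializing the full stream up front.
s1, s2 = itertools.tee(stream)

base = ((n[len("roberta."):], w) for n, w in s1 if n.startswith("roberta."))
task = ((n, w) for n, w in s2 if not n.startswith("roberta."))

print(list(base))  # [('encoder.w', 1)]
print(list(task))  # [('classifier.dense.w', 2)]

Worth noting: tee() buffers whatever one consumer has read ahead of the other, so the "without in-memory overhead" claim in the docstring holds only when the two outputs are consumed roughly in step.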
@@ -119,30 +95,6 @@ class RobertaEmbedding(nn.Module):
         return embeddings
 
 
-# Adapted from transformers
-def create_position_ids_from_input_ids(input_ids,
-                                       padding_idx,
-                                       past_key_values_length=0):
-    """
-    Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols
-    are ignored. This is modified from fairseq's `utils.make_positions`.
-
-    Args:
-        input_ids: torch.Tensor
-
-    Returns: torch.Tensor
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA.
-    mask = input_ids.ne(padding_idx).int()
-
-    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
-                           past_key_values_length) * mask
-
-    return incremental_indices.long() + padding_idx
-
-
 # Adapted from transformers
 class RobertaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
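As a worked example of what this (relocated, see the final hunk) helper computes, assuming the flattened 1-D token stream that vLLM passes, hence the cumsum over dim=0:

import torch

padding_idx = 1
input_ids = torch.tensor([5, 7, 1, 1])  # two real tokens, two pads

mask = input_ids.ne(padding_idx).int()            # tensor([1, 1, 0, 0])
incremental = torch.cumsum(mask, dim=0).type_as(mask) * mask
print(incremental.long() + padding_idx)           # tensor([2, 3, 1, 1])

Non-padding positions count up from padding_idx + 1 (here 2, 3) while pads stay pinned at padding_idx, matching RoBERTa's position-id convention.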
@@ -174,15 +126,38 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     def _build_model(self,
                      vllm_config: VllmConfig,
                      prefix: str = "") -> BertModel:
-        return BertModel(vllm_config=vllm_config,
-                         prefix=prefix,
-                         embedding_class=RobertaEmbedding)
+        if (vllm_config.model_config.hf_config.position_embedding_type ==
+                "rotary"):
+            config = vllm_config.model_config.hf_config
+            head_dim = config.hidden_size // config.num_attention_heads
+
+            rotary_kwargs = {
+                "head_size": head_dim,
+                "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+                "max_position": config.max_position_embeddings,
+                "base": config.rotary_emb_base,
+                "rope_scaling": getattr(config, "rope_scaling", None)
+            }
+
+            return BertModel(vllm_config=vllm_config,
+                             rotary_kwargs=rotary_kwargs,
+                             prefix=prefix)
+        else:
+            return BertModel(vllm_config=vllm_config,
+                             prefix=prefix,
+                             embedding_class=RobertaEmbedding)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        if getattr(self.config, "lora_rank", 0) > 0:
+            scaling = self.config.lora_alpha / self.config.lora_rank
+            weights = jina_merge_lora_weights(weights, scaling)
+
         weights = self.hf_to_vllm_mapper.apply(weights)
         # Separate weights into "roberta"-prefixed and all else (not in memory).
         # For use with models like FacebookAI/roberta-base.
         bert_weights, task_weights = roberta_task_weights_filter(weights)
+        bert_weights = jina_to_vllm_mapper.apply(bert_weights)
+
         loaded = self.model.load_weights(bert_weights)
         if not len(loaded):
             # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
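A small sketch of the new rotary branch in isolation; the SimpleNamespace stands in for the jina-embeddings-v3 HF config, and its values are illustrative, not taken from the actual checkpoint:

from types import SimpleNamespace

# Hypothetical config values; attribute names come straight from the diff.
config = SimpleNamespace(position_embedding_type="rotary",
                         hidden_size=1024,
                         num_attention_heads=16,
                         max_position_embeddings=8194,
                         rotary_emb_base=10000.0)

head_dim = config.hidden_size // config.num_attention_heads  # 64
rotary_kwargs = {
    "head_size": head_dim,
    "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
    "max_position": config.max_position_embeddings,
    "base": config.rotary_emb_base,
    "rope_scaling": getattr(config, "rope_scaling", None),
}
print(rotary_kwargs["rotary_dim"])  # falls back to head_dim: 64

The getattr fallbacks matter because rotary_emb_dim and rope_scaling are optional attributes on the upstream config.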
@@ -203,18 +178,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
     _pooler: An instance of Pooler used for pooling operations.
     """
 
-    jina_to_vllm_mapper = WeightsMapper(
-        orig_to_new_substr={
-            'emb_ln': "embeddings.LayerNorm",
-            'layers': "layer",
-            'mixer.Wqkv': "attention.self.qkv_proj",
-            'mixer.out_proj': "attention.output.dense",
-            'norm1': "attention.output.LayerNorm",
-            'mlp.fc1': "intermediate.dense",
-            'mlp.fc2': "output.dense",
-            'norm2': "output.LayerNorm",
-        })
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
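The mapper removed here reappears below at module scope. What it does is plain substring renaming of checkpoint keys; a reduced sketch of the idea (two entries only, with str.replace standing in for vLLM's WeightsMapper):

# Subset of the jina_to_vllm_mapper table above, for illustration.
SUBSTR_MAP = {
    "emb_ln": "embeddings.LayerNorm",
    "mixer.Wqkv": "attention.self.qkv_proj",
}

def rename(name: str) -> str:
    # Apply every substring substitution; order follows dict insertion.
    for old, new in SUBSTR_MAP.items():
        name = name.replace(old, new)
    return name

print(rename("emb_ln.weight"))
# -> embeddings.LayerNorm.weight
print(rename("layers.0.mixer.Wqkv.weight"))
# -> layers.0.attention.self.qkv_proj.weight  (the full mapper also
#    rewrites 'layers' to 'layer', omitted from this subset)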
@@ -232,7 +195,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         bert_weights, task_weights = roberta_task_weights_filter(weights)
-        bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
+        bert_weights = jina_to_vllm_mapper.apply(bert_weights)
 
         self.roberta.load_weights(bert_weights)
 
@@ -265,3 +228,105 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                            inputs_embeds=inputs_embeds,
                            intermediate_tensors=intermediate_tensors,
                            token_type_ids=token_type_ids)
+
+
+# Adapted from transformers
+def create_position_ids_from_input_ids(input_ids,
+                                       padding_idx,
+                                       past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers.
+    Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: torch.Tensor
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+
+    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
+                           past_key_values_length) * mask
+
+    return incremental_indices.long() + padding_idx
+
+
+def roberta_task_weights_filter(
+    all_weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
+                                                              torch.Tensor]]]:
+    """
+    Separate task-specific weights that are applied on top
+    of the encoder-decoder bert base.
+    To do so, return two generators over the original iterator.
+    Also, remove the "roberta." prefix to make it loadable
+    from vanilla BertModel.
+    """
+    # Copy of a lazy iterator without in-memory overhead so both
+    # iterators can be iterated upon independently.
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def encoder_decoder_weights():
+        for name, weight in all_weights1:
+            if name.startswith("roberta."):
+                yield (name[len("roberta."):], weight)
+
+    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
+                                       if not n.startswith("roberta."))
+
+
+jina_to_vllm_mapper = WeightsMapper(
+    orig_to_new_substr={
+        'emb_ln': "embeddings.LayerNorm",
+        'layers': "layer",
+        'mixer.Wqkv': "attention.self.qkv_proj",
+        'mixer.out_proj': "attention.output.dense",
+        'norm1': "attention.output.LayerNorm",
+        'mlp.fc1': "intermediate.dense",
+        'mlp.fc2': "output.dense",
+        'norm2': "output.LayerNorm",
+    })
+
+
+@torch.inference_mode()
+def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                            scaling: float = 1.0):
+    # Used for jina-embeddings-v3.
+    # Merge LoRA weights into a single weight tensor.
+    # This is a temporary solution until we have a better way to handle LoRA.
+
+    weights = {name: weight for name, weight in weights}
+
+    o = ".original"
+    a = ".0.lora_A"
+    b = ".0.lora_B"
+
+    # Select the "text-matching" adapter (the last LoRA in the stack).
+    i = -1
+
+    for name in list(weights.keys()):
+        if o in name:
+            dtype = weights[name].dtype
+            shape = weights[name].shape
+            weight_name = name[:-len(o)]
+
+            if "embeddings" in weight_name:
+                B = weights[weight_name + a][i].cuda().float()
+                A = weights[weight_name + b][i].cuda().float()
+            else:
+                B = weights[weight_name + b][i].cuda().float()
+                A = weights[weight_name + a][i].cuda().float()
+
+            weight = (weights[weight_name + o].cuda() +
+                      torch.matmul(B, A).view(shape) * scaling)
+            weight = weight.cpu().to(dtype)
+
+            weights[weight_name.replace(".parametrizations", "")] = weight
+
+            del weights[weight_name + o], weights[weight_name +
+                                                  a], weights[weight_name + b]
+
+    return [(name, weight) for name, weight in weights.items()]
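Numerically, the merge above computes W' = W + (B @ A) * scaling for the selected adapter. A tiny self-contained check of the shape algebra (dimensions are illustrative):

import torch

rank, d_in, d_out = 4, 16, 16
scaling = 1.0  # lora_alpha / lora_rank

W = torch.zeros(d_out, d_in)   # the ".original" base weight
A = torch.randn(rank, d_in)    # one adapter's ".0.lora_A" slice
B = torch.randn(d_out, rank)   # one adapter's ".0.lora_B" slice

merged = W + torch.matmul(B, A) * scaling
assert merged.shape == W.shape

The A/B swap for embedding weights in the diff presumably reflects that embedding LoRA factors are stored transposed relative to linear layers, and i = -1 picks the last of the stacked adapters, which the code's comment identifies as the text-matching task.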