[New Model]: jinaai/jina-embeddings-v3 (#16120)

This commit is contained in:
wang.yuqi
2025-04-08 23:39:12 +08:00
committed by GitHub
parent 90cb44eb02
commit 1f5d13ab9f
6 changed files with 297 additions and 86 deletions

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")
if self.position_embedding_type == "absolute":
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
elif self.position_embedding_type == "rotary":
self.position_embeddings = None
self.position_ids = None
else:
raise ValueError("Only 'absolute' and 'rotary' " +
"position_embedding_type is supported")
def forward(
self,
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
@support_torch_compile
class BertEncoder(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
vllm_config: VllmConfig,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
BertLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
for layer in self.layer:
hidden_states = layer(hidden_states)
hidden_states = layer(positions, hidden_states)
return hidden_states
@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention")
self.intermediate = BertIntermediate(
@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.output")
def forward(self, hidden_states: torch.Tensor):
attn_output = self.attention(hidden_states)
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
attn_output = self.attention(positions, hidden_states)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
super().__init__()
@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
num_attention_heads=num_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size,
@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
self_output = self.self(hidden_states)
self_output = self.self(positions, hidden_states)
return self.output(self_output, hidden_states)
@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
super().__init__()
@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
if rotary_kwargs:
self.rotary_emb = get_rope(**rotary_kwargs)
else:
self.rotary_emb = None
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb:
q, k = self.rotary_emb(positions, q, k)
output = self.attn(q, k, v)
return output
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding,
rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None
@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)
return self.encoder(position_ids, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.config = vllm_config.model_config.hf_config
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)