[New Model]: jinaai/jina-embeddings-v3 (#16120)
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
|
||||
PoolingType)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
|
||||
self.size = config.hidden_size
|
||||
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.position_embeddings = VocabParallelEmbedding(
|
||||
config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
self.token_type_embeddings = VocabParallelEmbedding(
|
||||
config.type_vocab_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.position_ids = nn.Parameter(
|
||||
torch.empty((1, config.max_position_embeddings)), )
|
||||
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError("Only 'absolute' position_embedding_type" +
|
||||
" is supported")
|
||||
if self.position_embedding_type == "absolute":
|
||||
self.position_embeddings = VocabParallelEmbedding(
|
||||
config.max_position_embeddings, config.hidden_size)
|
||||
self.position_ids = nn.Parameter(
|
||||
torch.empty((1, config.max_position_embeddings)), )
|
||||
elif self.position_embedding_type == "rotary":
|
||||
self.position_embeddings = None
|
||||
self.position_ids = None
|
||||
else:
|
||||
raise ValueError("Only 'absolute' and 'rotary' " +
|
||||
"position_embedding_type is supported")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
|
||||
# Input embeddings.
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
# Position embeddings.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape,
|
||||
dtype=torch.long,
|
||||
@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
|
||||
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||
embeddings = inputs_embeds + token_type_embeddings
|
||||
|
||||
if self.position_embedding_type == "absolute":
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings += position_embeddings
|
||||
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
return embeddings
|
||||
|
||||
@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
|
||||
@support_torch_compile
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
|
||||
BertLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layer:
|
||||
hidden_states = layer(hidden_states)
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
|
||||
config: BertConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
|
||||
layer_norm_eps=config.layer_norm_eps,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
self.intermediate = BertIntermediate(
|
||||
@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
attn_output = self.attention(hidden_states)
|
||||
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
|
||||
attn_output = self.attention(positions, hidden_states)
|
||||
intermediate_output = self.intermediate(attn_output)
|
||||
output = self.output(intermediate_output, attn_output)
|
||||
return output
|
||||
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
|
||||
layer_norm_eps: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
|
||||
num_attention_heads=num_attention_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
self.output = BertSelfOutput(hidden_size=hidden_size,
|
||||
@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
self_output = self.self(hidden_states)
|
||||
self_output = self.self(positions, hidden_states)
|
||||
return self.output(self_output, hidden_states)
|
||||
|
||||
|
||||
@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
|
||||
num_attention_heads: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
if rotary_kwargs:
|
||||
self.rotary_emb = get_rope(**rotary_kwargs)
|
||||
else:
|
||||
self.rotary_emb = None
|
||||
|
||||
self.attn = Attention(num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
if self.rotary_emb:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
output = self.attn(q, k, v)
|
||||
return output
|
||||
|
||||
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type = BertEmbedding,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
add_pooling_layer: bool = False):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.embeddings = embedding_class(config)
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
|
||||
@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
seq_lens=attn_metadata.seq_lens_tensor,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
return self.encoder(hidden_states)
|
||||
return self.encoder(position_ids, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = self._build_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = self._build_pooler(pooler_config)
|
||||
|
||||
Reference in New Issue
Block a user