Support Roberta embedding models (#9387)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
committed by
GitHub
parent
1dbae0329c
commit
4a18fd14ba
@@ -5,7 +5,7 @@ from torch import nn
|
||||
from transformers import BertConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@@ -305,14 +305,16 @@ class BertOutput(nn.Module):
|
||||
|
||||
class BertModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type = BertEmbedding):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.embeddings = BertEmbedding(config)
|
||||
self.embeddings = embedding_class(config)
|
||||
self.encoder = BertEncoder(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
@@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.model = BertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.CLS,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.model = self._build_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -415,3 +413,16 @@ class BertEmbeddingModel(nn.Module):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(weights)
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
return BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
embedding_class=BertEmbedding)
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return Pooler.from_config_with_defaults(pooler_config,
|
||||
pooling_type=PoolingType.CLS,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
|
||||
Reference in New Issue
Block a user