Support Roberta embedding models (#9387)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Maximilien de Bayser
2024-11-14 18:23:29 -03:00
committed by GitHub
parent 1dbae0329c
commit 4a18fd14ba
10 changed files with 202 additions and 14 deletions

View File

@@ -5,7 +5,7 @@ from torch import nn
from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -305,14 +305,16 @@ class BertOutput(nn.Module):
class BertModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embeddings = BertEmbedding(config)
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(config,
cache_config,
quant_config,
@@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
# NOTE(review): this span is a version-control diff hunk rendered without
# +/- markers, so the pre-change and post-change bodies of this method
# appear together below. Do not treat the whole span as runnable code.
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
# NOTE(review): the next seven lines are the OLD (removed) implementation,
# which constructed the BertModel backbone and the CLS pooler inline.
self.model = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
# NOTE(review): the NEW (added) implementation delegates to the
# overridable factory hooks `_build_model` / `_build_pooler`, so that
# subclasses can customize the backbone and pooler — per the commit
# title, this enables RoBERTa embedding-model support.
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)
def forward(
self,
@@ -415,3 +413,16 @@ class BertEmbeddingModel(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Delegate checkpoint weight loading to the wrapped backbone model."""
    backbone = self.model
    backbone.load_weights(weights)
def _build_model(self,
                 vllm_config: VllmConfig,
                 prefix: str = "") -> BertModel:
    """Factory hook: construct the BERT backbone.

    Uses the standard ``BertEmbedding`` layer; subclasses override this
    method to substitute a different embedding class (the commit title
    indicates this is how RoBERTa variants plug in — confirm upstream).
    """
    backbone_kwargs = dict(vllm_config=vllm_config,
                           prefix=prefix,
                           embedding_class=BertEmbedding)
    return BertModel(**backbone_kwargs)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
    """Factory hook: build the default pooler.

    Defaults to CLS-token pooling with normalization enabled and
    softmax disabled; subclasses may override to change pooling.
    """
    pooling_defaults = {
        "pooling_type": PoolingType.CLS,
        "normalize": True,
        "softmax": False,
    }
    return Pooler.from_config_with_defaults(pooler_config,
                                            **pooling_defaults)