Support Cross encoder models (#10400)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Maximilien de Bayser
2024-11-24 23:56:20 -03:00
committed by GitHub
parent 49628fe13e
commit 214efc2c3c
28 changed files with 1370 additions and 62 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -6,10 +6,17 @@ from transformers import RobertaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
class RobertaEmbedding(nn.Module):
@@ -39,34 +46,93 @@ class RobertaEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)
# TODO: figure out if there is a better way
# to make to make position ids start at padding_idx + 1
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
position_ids += self.padding_idx + 1
pos_list = []
token_list = []
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[offset:offset + seq_len])
token_list.append(input_ids[offset:offset + seq_len])
offset += seq_len
new_pos_list = []
for positions, tokens in zip(pos_list, token_list):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
# Token type embeddings. (TODO: move off hotpath?)
token_type_embeddings = self.token_type_embeddings(
torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device))
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
padding_idx,
past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
past_key_values_length) * mask
return incremental_indices.long() + padding_idx
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config: RobertaConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[0, :] # take <s> token (equiv. to [CLS])
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
return x
class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.
@@ -85,6 +151,62 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
prefix=prefix,
embedding_class=RobertaEmbedding)
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
roberta: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.default_activation_function = \
get_cross_encoder_activation_function(config)
self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
embedding_class=RobertaEmbedding,
add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config)
self._pooler = CrossEncodingPooler(config, self.classifier)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self_weights = []
def weight_filter():
for name, weight in weights:
if name.startswith("roberta."):
yield (name[len("roberta."):], weight)
else:
self_weights.append((name, weight))
self.roberta.load_weights(weight_filter())
params_dict = dict(self.named_parameters())
for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -93,25 +215,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Verify assumption that position are always a sequence from
# 0 to N. (Actually here we just check 0 and N to simplify).
# This is important to fix the position which are assumed to
# start from padding_idx + 1 instead of 0 in the Roberta models.
assert hasattr(attn_metadata, "seq_lens_tensor")
cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
start_pos = torch.cat(
(torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
cumulative[:-1]))
assert len(torch.nonzero(positions[start_pos])) == 0
end_pos = cumulative - 1
last_tokens = attn_metadata.seq_lens_tensor - 1
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0
return super().forward(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return self.roberta(input_ids=input_ids,
position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata,
token_type_ids=token_type_ids)