Support token_type_ids in V1 with less code changes (#21985)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
committed by GitHub
parent 9c97a1c349
commit 39052dbca8
@@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
+from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
+                                             BertEmbeddingModel, BertModel,
+                                             _decode_token_type_ids,
+                                             _encode_token_type_ids)
 from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                               maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 from .bert_with_rope import BertWithRope, JinaRobertaModel
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding
 
 
 class RobertaEmbedding(nn.Module):
@@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        input_shape = input_ids.size()
-        inputs_embeds = self.word_embeddings(input_ids)
 
-        # Position embeddings.
+        token_type_ids = _decode_token_type_ids(input_ids)
+
+        inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
                                   position_ids=positions,
                                   padding_idx=self.padding_idx)
 
-        return self.model(input_ids,
-                          positions,
-                          token_type_ids=token_type_ids,
+        return self.model(input_ids=input_ids,
+                          positions=positions,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)
 
@@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         return loader.load_weights(weights_list, mapper=mapper)
 
 
-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
-                                       SupportsV0Only):
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
         replace_roberta_positions(input_ids=input_ids,
                                   position_ids=positions,
                                   padding_idx=self.padding_idx)
+        if token_type_ids is not None:
+            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+            assert input_ids is not None
+            _encode_token_type_ids(input_ids, token_type_ids)
         return self.roberta(input_ids=input_ids,
-                            position_ids=positions,
+                            positions=positions,
                             inputs_embeds=inputs_embeds,
-                            intermediate_tensors=intermediate_tensors,
-                            token_type_ids=token_type_ids)
+                            intermediate_tensors=intermediate_tensors)
 
 
 # Adapted from transformers
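
Note: TOKEN_TYPE_SHIFT, _encode_token_type_ids and _decode_token_type_ids come from vllm/model_executor/models/bert.py and are not shown in this diff. The assert that vocab_size < (1 << TOKEN_TYPE_SHIFT) only makes sense if the token type is packed into the bits of input_ids above the vocabulary range, which lets the type information travel through the V1 model runner without threading an extra tensor through it. A minimal sketch of that packing idea, assuming a shift-based layout (the shift value and the in-place tensor ops below are illustrative assumptions, not the verbatim vLLM helpers):

import torch

# Assumed shift; the diff only requires vocab_size < (1 << TOKEN_TYPE_SHIFT).
TOKEN_TYPE_SHIFT = 30


def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    # In place: stash each token's type id in the bits above the vocab range.
    # Both tensors are assumed to be integer tensors of the same dtype.
    input_ids.bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT)


def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    # Recover the type ids and restore plain token ids, again in place.
    token_type_ids = input_ids >> TOKEN_TYPE_SHIFT
    input_ids.bitwise_and_((1 << TOKEN_TYPE_SHIFT) - 1)
    return token_type_ids

With this layout, RobertaForSequenceClassification encodes only when token_type_ids is provided, while RobertaEmbedding decodes unconditionally: unencoded input_ids have empty high bits and simply decode to type 0, so ordinary embedding requests are unaffected.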