Support token_type_ids in V1 with less code changes (#21985)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
committed by GitHub
parent 9c97a1c349
commit 39052dbca8
@@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask

-from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -60,21 +60,13 @@ class BertEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        input_shape = input_ids.size()
-
         # Input embeddings.
+        token_type_ids = _decode_token_type_ids(input_ids)
+
         inputs_embeds = self.word_embeddings(input_ids)

         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)

-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)
-
         token_type_embeddings = self.token_type_embeddings(token_type_ids)

         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
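Note: after this change, BertEmbedding.forward assumes the token type ids arrive packed into the high bits of input_ids. A minimal standalone sketch of that layout, with illustrative values (not vLLM code; TOKEN_TYPE_SHIFT is defined later in this diff):

import torch

TOKEN_TYPE_SHIFT = 30  # bit 30 carries the token type; bits 0-29 carry the vocab id

# Toy packed batch: vocab ids 101 and 102, second token of type 1.
packed = torch.tensor([101, 102 | (1 << TOKEN_TYPE_SHIFT)], dtype=torch.int32)

token_type_ids = (packed >> TOKEN_TYPE_SHIFT) & 1   # -> tensor([0, 1])
vocab_ids = packed & ((1 << TOKEN_TYPE_SHIFT) - 1)  # -> tensor([101, 102])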
@@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant):
     ) -> None:
         super().__init__()

-        config = vllm_config.model_config.hf_config
-        self.embeddings = embedding_class(config)
+        self.config = vllm_config.model_config.hf_config
+        self.embeddings = embedding_class(self.config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")

     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
            hidden_states = inputs_embeds
         else:
             hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=position_ids,
-                                            token_type_ids=token_type_ids)
+                                            position_ids=positions)
         return self.encoder(hidden_states)

     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
-                          position_ids=positions,
-                          token_type_ids=token_type_ids,
+                          positions=positions,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)
@@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         })


-class BertForSequenceClassification(nn.Module, SupportsV0Only,
-                                    SupportsCrossEncoding, SupportsQuant):
+# Here we encode the token type ids together with the input ids.
+# Since we use int32 for the input IDs and the vocabulary size
+# is way lower than 2**31, there is room to encode additional
+# bits. At the same time, for cross-encoder use cases, the
+# token type ids are only 0 or 1, requiring only 1 bit.
+# This means that we can store the token type ids in the 31st
+# bit. We avoid the 32nd bit because that would produce a negative
+# number, which could be used to signal other things.
+#
+# The reason for all of this is that all the tensors that are
+# passed as input to the forward function of a module marked
+# with @support_torch_compile have to be persistent. So to
+# avoid adding more persistent tensors in the model runner, we
+# encode more information in the same persistent tensor.
+#
+# Since the *ForClassification module is outside of the BertModel
+# which is compiled, we can do the encoding here and then separate
+# the information again in the Embedding layer. Since with bit masks
+# we can do this entirely with torch operations and without branching,
+# it works with torch compile.
+
+TOKEN_TYPE_SHIFT = 30
+
+
+def _encode_token_type_ids(input_ids: torch.Tensor,
+                           token_type_ids: torch.Tensor) -> None:
+    # input_ids can be padded to the right
+    input_ids[:token_type_ids.shape[0]].bitwise_or_(
+        token_type_ids << TOKEN_TYPE_SHIFT)
+
+
+def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
+
+    ids_mask = torch.ones(input_ids.shape,
+                          dtype=torch.int32,
+                          device=input_ids.device) << TOKEN_TYPE_SHIFT
+    tokens_mask = ids_mask.bitwise_not()
+
+    token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
+
+    input_ids.bitwise_and_(tokens_mask)
+
+    return token_type_ids
+
+
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                    SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
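A round-trip sanity check of the two helpers added above. This sketch re-inlines them so it runs standalone; the token values are illustrative, not taken from this diff:

import torch

TOKEN_TYPE_SHIFT = 30

def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    # OR the type bit into the high bits of the ids, in place.
    input_ids[:token_type_ids.shape[0]].bitwise_or_(
        token_type_ids << TOKEN_TYPE_SHIFT)

def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    ids_mask = torch.ones(input_ids.shape,
                          dtype=torch.int32,
                          device=input_ids.device) << TOKEN_TYPE_SHIFT
    token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
    input_ids.bitwise_and_(ids_mask.bitwise_not())  # strip the type bits in place
    return token_type_ids

ids = torch.tensor([101, 2023, 102, 2003, 102], dtype=torch.int32)
types = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

original = ids.clone()
_encode_token_type_ids(ids, types)
assert torch.equal(_decode_token_type_ids(ids), types)  # types recovered
assert torch.equal(ids, original)                       # ids restored exactly

Because both helpers are pure mask-and-shift tensor ops, the round trip involves no data-dependent branching, which is what lets the decode side live inside the compiled BertModel.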
@@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if token_type_ids is not None:
+            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+            assert input_ids is not None
+            _encode_token_type_ids(input_ids, token_type_ids)
+
         return self.bert(input_ids=input_ids,
-                         position_ids=positions,
+                         positions=positions,
                          inputs_embeds=inputs_embeds,
-                         intermediate_tensors=intermediate_tensors,
-                         token_type_ids=token_type_ids)
+                         intermediate_tensors=intermediate_tensors)
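The assert above guards the packing invariant: every legal vocab id must leave bit 30 clear, which holds comfortably for BERT-sized vocabularies. An illustrative check (vocab size of bert-base-uncased, assumed here):

TOKEN_TYPE_SHIFT = 30
vocab_size = 30522                             # bert-base-uncased, for illustration
assert vocab_size < (1 << TOKEN_TYPE_SHIFT)    # 30522 < 1073741824
# Hence id | (type << 30) never collides with a real id and never touches
# the sign bit (bit 31), so the packed int32 stays non-negative.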