Support encoder-only models without KV-Cache (#21270)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
committed by
GitHub
parent
f27fdfc3ed
commit
1cd6eaba54
@@ -9,6 +9,7 @@ from torch import nn
|
||||
from transformers import RobertaConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
|
||||
DispatchPooler, Pooler)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -51,33 +52,12 @@ class RobertaEmbedding(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
input_shape = input_ids.size()
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
# Replace position ids because in RoBERTa models
|
||||
# they have to start at padding_idx + 1 and ignore
|
||||
# existing padding tokens
|
||||
# References:
|
||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
new_pos_list = []
|
||||
for positions, tokens in zip(position_ids.split(seq_lens_list),
|
||||
input_ids.split(seq_lens_list)):
|
||||
# Verify assumption that incoming position are
|
||||
# always a sequence from 0 to N.
|
||||
expected_pos = torch.arange(positions.size()[0],
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device)
|
||||
assert torch.equal(positions, expected_pos)
|
||||
new_pos_list.append(
|
||||
create_position_ids_from_input_ids(tokens, self.padding_idx))
|
||||
position_ids = torch.cat(new_pos_list)
|
||||
|
||||
# Position embeddings.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
if token_type_ids is None:
|
||||
@@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Fix Roberta positions here outside of the CUDA graph.
|
||||
# Because we need the to extract the sequences from
|
||||
# input_ids the control flow is data dependent.
|
||||
replace_roberta_positions(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
padding_idx=self.padding_idx)
|
||||
|
||||
return self.model(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> Union[BertModel, BertWithRope]:
|
||||
@@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.roberta = BertModel(vllm_config=vllm_config,
|
||||
@@ -216,6 +223,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
replace_roberta_positions(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
padding_idx=self.padding_idx)
|
||||
return self.roberta(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids,
|
||||
past_key_values_length) * mask
|
||||
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
def replace_roberta_positions(input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
padding_idx: int) -> None:
|
||||
|
||||
seq_lens: Optional[torch.Tensor] = None
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None: # can be None during warmup
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = next(iter(attn_metadata.values()))
|
||||
# TODO: remove "seq_lens_tensor" after V0 is removed
|
||||
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
|
||||
getattr(attn_metadata, "seq_lens", None))
|
||||
|
||||
if seq_lens is not None:
|
||||
assert isinstance(seq_lens, torch.Tensor)
|
||||
|
||||
# Replace position ids because in RoBERTa models
|
||||
# they have to start at padding_idx + 1 and ignore
|
||||
# existing padding tokens
|
||||
# References:
|
||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
||||
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
|
||||
seq_lens.tolist())
|
||||
|
||||
offset = 0
|
||||
for tokens in token_list:
|
||||
length = tokens.shape[0]
|
||||
position_ids[offset:offset+length] = \
|
||||
create_position_ids_from_input_ids(tokens, padding_idx)
|
||||
offset = offset + length
|
||||
|
||||
Reference in New Issue
Block a user