Support encoder-only models without KV-Cache (#21270)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Maximilien de Bayser
2025-07-26 10:09:52 -03:00
committed by GitHub
parent f27fdfc3ed
commit 1cd6eaba54
17 changed files with 352 additions and 99 deletions

View File

@@ -9,6 +9,7 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,33 +52,12 @@ class RobertaEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
seq_lens_list = seq_lens.tolist()
new_pos_list = []
for positions, tokens in zip(position_ids.split(seq_lens_list),
input_ids.split(seq_lens_list)):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
@@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Fix Roberta positions here outside of the CUDA graph.
# Because we need the to extract the sequences from
# input_ids the control flow is data dependent.
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
@@ -216,6 +223,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.roberta(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
@@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids,
past_key_values_length) * mask
return incremental_indices.long() + padding_idx
def replace_roberta_positions(input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_idx: int) -> None:
seq_lens: Optional[torch.Tensor] = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None: # can be None during warmup
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
getattr(attn_metadata, "seq_lens", None))
if seq_lens is not None:
assert isinstance(seq_lens, torch.Tensor)
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
seq_lens.tolist())
offset = 0
for tokens in token_list:
length = tokens.shape[0]
position_ids[offset:offset+length] = \
create_position_ids_from_input_ids(tokens, padding_idx)
offset = offset + length