diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 46211e6ed..c7c292e70 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -79,7 +79,14 @@ class RobertaEmbedding(nn.Module):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
+        # RoBERTa positions start at padding_idx + 1 instead of 0.
+        # Use non-in-place add to avoid mutating the persistent positions
+        # buffer -- in-place += would accumulate on CUDA graph padding
+        # slots that aren't refreshed between requests, eventually
+        # overflowing max_position_embeddings.
+        position_embeddings = self.position_embeddings(
+            position_ids + self.padding_idx + 1
+        )
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -123,13 +130,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        # Fix Roberta positions here outside of the CUDA graph.
-        # Because we need the to extract the sequences from
-        # input_ids the control flow is data dependent.
-        replace_roberta_positions(
-            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
-        )
-
         return self.model(
             input_ids=input_ids,
             positions=positions,
@@ -324,9 +324,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
         inputs_embeds: torch.Tensor | None = None,
         token_type_ids: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        replace_roberta_positions(
-            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
-        )
         if token_type_ids is not None:
             assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
             assert input_ids is not None
@@ -337,16 +334,3 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
             inputs_embeds=inputs_embeds,
             intermediate_tensors=intermediate_tensors,
         )
-
-
-def replace_roberta_positions(
-    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
-) -> None:
-    # Replace position ids because in RoBERTa models
-    # they have to start at padding_idx + 1 and ignore
-    # existing padding tokens
-    # References:
-    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-    # vllm does not use padding tokens, let's make things simpler
-    position_ids += padding_idx + 1
diff --git a/vllm/model_executor/models/transformers/legacy.py b/vllm/model_executor/models/transformers/legacy.py
index 1704d0bfd..49c5e9dcf 100644
--- a/vllm/model_executor/models/transformers/legacy.py
+++ b/vllm/model_executor/models/transformers/legacy.py
@@ -65,8 +65,10 @@ class LegacyMixin:
         inputs_embeds: torch.Tensor | None = None,
     ) -> torch.Tensor | IntermediateTensors:
         if self.is_roberta:
-            # RoBERTa-specific positions padding
-            positions += self.padding_idx + 1
+            # RoBERTa positions start at padding_idx + 1.
+            # Non-in-place add to avoid mutating the persistent GPU buffer --
+            # in-place += would accumulate on CUDA graph padding slots.
+            positions = positions + self.padding_idx + 1
         return super().forward(
             input_ids=input_ids,
             positions=positions,
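
The change hinges on one behavior: under CUDA graphs the positions tensor is a persistent input buffer that gets replayed rather than reallocated, so an in-place += keeps shifting whatever is left in the padding slots. A minimal sketch of that failure mode, outside vLLM, with an assumed padding_idx of 1 and a plain CPU tensor standing in for the captured buffer:

    import torch

    padding_idx = 1                     # assumed value (RoBERTa's pad token id)
    positions = torch.arange(8)         # stand-in for the persistent positions buffer

    # In-place add: the shift accumulates every time the same buffer is reused,
    # e.g. on CUDA graph padding slots that are never rewritten between replays.
    buf = positions.clone()
    for _ in range(3):
        buf += padding_idx + 1
    print(buf[:4])                      # tensor([6, 7, 8, 9]) -- drifted past the intended +2

    # Non-in-place add: the buffer is read but never modified.
    buf = positions.clone()
    for _ in range(3):
        shifted = buf + padding_idx + 1  # fresh tensor each iteration, buffer untouched
    print(buf[:4], shifted[:4])          # tensor([0, 1, 2, 3]) tensor([2, 3, 4, 5])

The non-in-place form trades a small temporary allocation for correctness; the offset is applied where the embedding lookup happens, so it is safe to capture inside the graph.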