[Bugfix] Fix RobertaModel loading (#11940)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -20,6 +21,30 @@ from vllm.transformers_utils.config import (
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
|
||||
|
||||
def roberta_task_weights_filter(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                              torch.Tensor]]]:
    """Split a weight stream into base-model and task-specific weights.

    Entries whose name starts with ``"roberta."`` are yielded by the first
    returned iterable with that prefix stripped, so that the names can be
    loaded by a vanilla BertModel. Every other entry is yielded unchanged
    by the second returned iterable.

    Both iterables are lazy and can be consumed independently, in any
    order. They share the source through ``itertools.tee``, which buffers
    the items one side has consumed until the other side consumes them too.

    Args:
        all_weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        A ``(base_weights, task_weights)`` pair of lazy iterables.
    """
    prefix = "roberta."
    base_source, task_source = itertools.tee(all_weights)

    # Prefixed entries, renamed to drop the leading "roberta.".
    base_weights = ((name[len(prefix):], weight)
                    for name, weight in base_source
                    if name.startswith(prefix))

    def task_weights():
        # Everything that does not carry the prefix passes through as-is.
        for name, weight in task_source:
            if not name.startswith(prefix):
                yield (name, weight)

    return base_weights, task_weights()
|
||||
|
||||
|
||||
class RobertaEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, config: RobertaConfig):
|
||||
@@ -152,6 +177,18 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
prefix=prefix,
|
||||
embedding_class=RobertaEmbedding)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the wrapped model.

    The incoming names are first passed through ``hf_to_vllm_mapper``.
    The stream is then split into "roberta."-prefixed entries (prefix
    stripped) and all remaining entries. The prefixed entries are loaded
    first; if that loads nothing, the remaining entries are loaded instead.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Raises:
        AssertionError: If neither attempt loads any weight.
    """
    mapped = self.hf_to_vllm_mapper.apply(weights)
    # Lazily split into "roberta."-prefixed weights and everything else.
    # The prefixed form is what models like FacebookAI/roberta-base use.
    prefixed, unprefixed = roberta_task_weights_filter(mapped)
    loaded = self.model.load_weights(prefixed)
    if len(loaded) == 0:
        # Models like `sentence-transformers/stsb-roberta-base-v2` use the
        # same architecture but have no "roberta" prefix, so fall back to
        # the unprefixed weights.
        loaded = self.model.load_weights(unprefixed)
    assert len(loaded) > 0, "Unable to load RobertaEmbeddingModel"
|
||||
|
||||
|
||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
@@ -181,20 +218,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
self_weights = []
|
||||
|
||||
def weight_filter():
|
||||
for name, weight in weights:
|
||||
if name.startswith("roberta."):
|
||||
yield (name[len("roberta."):], weight)
|
||||
else:
|
||||
self_weights.append((name, weight))
|
||||
|
||||
self.roberta.load_weights(weight_filter())
|
||||
bert_weights, task_weights = roberta_task_weights_filter(weights)
|
||||
self.roberta.load_weights(bert_weights)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in self_weights:
|
||||
for name, loaded_weight in task_weights:
|
||||
if name.startswith("classifier"):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
|
||||
Reference in New Issue
Block a user