Fix pooling adapters for Transformers backend (#27338)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-24 04:23:55 +01:00
committed by GitHub
parent 70022ffc00
commit 1f9460c4c1
6 changed files with 97 additions and 74 deletions

View File

@@ -49,6 +49,7 @@ from vllm.model_executor.models.transformers.utils import (
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
@@ -92,6 +93,27 @@ ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints,
# handling the case where it is already present
"": "model.",
"model.model.": "model.",
# Heads will be adjacent to `model` (pooling included because of adapters)
"model.lm_head.": "lm_head.",
"model.score.": "classifier.",
"model.classifier.": "classifier.",
}
)
def __init_subclass__(cls, *args, **kwargs):
"""Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
super().__init_subclass__(*args, **kwargs)
hf_to_vllm_mapper = WeightsMapper()
for base in cls.__mro__:
if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
hf_to_vllm_mapper |= base_hf_to_vllm_mapper
cls.hf_to_vllm_mapper = hf_to_vllm_mapper
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__()

View File

@@ -34,13 +34,6 @@ class LegacyMixin:
# Handle BERT-like models
"roberta": "model",
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Classifier/scoring heads will be adjacent to `model`
"model.score": "classifier",
"model.classifier": "classifier",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms

View File

@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
import torch
from transformers import AutoModelForSequenceClassification
from vllm.config.utils import getattr_iter
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
@@ -82,14 +83,14 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
if hasattr(module, "pooler") and module.pooler is None:
self.model.pooler = None
break
if self.model.pooler is not None:
raise ValueError(
"Sequence classification models with pooling layers are not "
"supported yet in the Transformers backend."
)
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
self.classifier = seq_cls_model.classifier
self.classifier = getattr_iter(seq_cls_model, ["classifier", "score"], None)
if self.classifier is None:
raise ValueError(
"Could not find `classifier` or `score` layer in the "
"`AutoModelForSequenceClassification` instance."
)
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
class ClassifierWithReshape(self.classifier.__class__):