Enable headless models for pooling in the Transformers backend (#21767)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -270,8 +270,9 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
|
||||
}
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
||||
"TransformersModel": ("transformers", "TransformersModel"),
|
||||
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@@ -651,6 +651,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TransformersModel(TransformersBase):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# Add `model.` prefix for base model checkpoints
|
||||
"": "model.",
|
||||
# Remove `model.` from places it should not be
|
||||
"model.model.": "model.",
|
||||
"model.score": "score",
|
||||
})
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TransformersForCausalLM(TransformersBase):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user