[Bugfix] Fix encoder-only model support for transformers backend (#28021)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Isotr0py
2025-11-05 14:24:41 +08:00
committed by GitHub
parent 428bc7bf1c
commit 0ff05e3770
4 changed files with 16 additions and 10 deletions

View File

@@ -28,6 +28,7 @@ from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
@@ -317,7 +318,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
# vLLM does not support encoder-decoder models, so if any encoder layer is
# found in a text only model, we assume the whole model is an encoder model
if has_encoder(self.model) and not is_multimodal(self.config):
self.check_version("4.57.0.dev0", "encoder models support")
self.check_version("5.0.0.dev0", "encoder models support")
attn_type = AttentionType.ENCODER_ONLY
else:
attn_type = AttentionType.DECODER
@@ -336,7 +337,12 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
):
per_layer_sliding_window = self.config.sliding_window
attention_instances[i] = Attention(
attn_cls = (
EncoderOnlyAttention
if attn_type == AttentionType.ENCODER_ONLY
else Attention
)
attention_instances[i] = attn_cls(
num_heads=num_heads,
head_size=head_size,
# NOTE: We use Llama scale as default, if it's set by