[Model] Systematic support for fp32 head, pooling models part (#23810)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-09-09 22:29:50 +08:00
committed by GitHub
parent a55cf41a09
commit 19332c0479
14 changed files with 166 additions and 61 deletions

View File

@@ -5,9 +5,9 @@ from typing import Optional
import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
class JinaVLScorer(nn.Module):
def __init__(self, config: PretrainedConfig):
def __init__(self, model_config: "ModelConfig"):
    """Build the two-layer scoring head.

    Args:
        model_config: vLLM ``ModelConfig``; supplies the HF config
            (``hidden_size``, ``num_labels``) and ``head_dtype``, the
            dtype the head's parameters are created in (per the fp32
            head support, this may differ from the backbone dtype).
    """
    super().__init__()
    hf_cfg = model_config.hf_config
    head_dtype = model_config.head_dtype

    # Hidden-size -> hidden-size dense projection, then a projection
    # down to num_labels; both built with params_dtype=head_dtype so
    # the scoring head can run at a higher precision than the trunk.
    self.dense = ColumnParallelLinear(
        hf_cfg.hidden_size,
        hf_cfg.hidden_size,
        params_dtype=head_dtype,
        bias=True,
    )
    self.out_proj = RowParallelLinear(
        hf_cfg.hidden_size,
        hf_cfg.num_labels,
        params_dtype=head_dtype,
        bias=True,
    )
def forward(self, x, **kwargs):
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "qwen2_vl"))
config = vllm_config.model_config.hf_config
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.score = JinaVLScorer(config)
self.score = JinaVLScorer(vllm_config.model_config)
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),