[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -5,9 +5,9 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
|
||||
|
||||
class JinaVLScorer(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(self, model_config: "ModelConfig"):
|
||||
super().__init__()
|
||||
config = model_config.hf_config
|
||||
head_dtype = model_config.head_dtype
|
||||
self.dense = ColumnParallelLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
params_dtype=head_dtype,
|
||||
bias=True)
|
||||
self.out_proj = RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
params_dtype=head_dtype,
|
||||
bias=True)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "qwen2_vl"))
|
||||
config = vllm_config.model_config.hf_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.score = JinaVLScorer(config)
|
||||
self.score = JinaVLScorer(vllm_config.model_config)
|
||||
self.pooler = DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
|
||||
Reference in New Issue
Block a user