[Model] Systematic support for fp32 head, pooling models part (#23810)

Signed-off-by: wang.yuqi <noooop@126.com>
Authored by wang.yuqi on 2025-09-09 22:29:50 +08:00; committed by GitHub
parent a55cf41a09
commit 19332c0479
14 changed files with 166 additions and 61 deletions
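
Note: every file below follows the same pattern. The classification / scoring / reward head is created directly in model_config.head_dtype instead of a hard-coded torch.float32, and hidden states are cast to that dtype right before the head runs. A minimal standalone sketch of the idea (the class and attribute names here are illustrative, not part of the diff):

    import torch
    import torch.nn as nn

    class SketchClassificationModel(nn.Module):
        """Toy model: low-precision backbone, head built in head_dtype (e.g. fp32)."""

        def __init__(self, hidden_size: int, num_labels: int,
                     model_dtype: torch.dtype = torch.bfloat16,
                     head_dtype: torch.dtype = torch.float32):
            super().__init__()
            self.head_dtype = head_dtype
            # Backbone stays in the model dtype.
            self.backbone = nn.Linear(hidden_size, hidden_size, dtype=model_dtype)
            # Head is created in head_dtype, mirroring
            # nn.Linear(..., dtype=vllm_config.model_config.head_dtype) in the hunks below.
            self.score = nn.Linear(hidden_size, num_labels, dtype=head_dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            hidden_states = self.backbone(x)
            # Cast before the head, mirroring hidden_states.to(self.head_dtype) below.
            return self.score(hidden_states.to(self.head_dtype))

Here head_dtype plays the role of the ModelConfig.head_dtype field that the hunks below rely on.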

View File

@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
 linear = nn.Linear(layer_config.get("in_features", 768),
                    layer_config.get("out_features", 768),
                    bias=layer_config.get("bias", True),
-                   dtype=torch.float32)
+                   dtype=model_config.head_dtype)
 if not _load_dense_weights(linear, folder, model_config):
     continue
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
 layers.append(linear)
 if act_name := layer_config.get("activation_function"):
     layers.append(get_act_fn(act_name))
-return nn.Sequential(*layers).to(dtype=torch.float32)
+return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
 except Exception:
     logger.exception("ST projector loading failed")
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
 if weight_key in state_dict:
     weight_loader = getattr(linear.weight, "weight_loader",
                             default_weight_loader)
-    weight_loader(linear.weight,
-                  state_dict[weight_key].to(torch.float32))
+    weight_loader(linear.weight, state_dict[weight_key])
     bias_key = weight_key.replace("weight", "bias")
     if linear.bias is not None and bias_key in state_dict:
         bias_loader = getattr(linear.bias, "weight_loader",
                               default_weight_loader)
-        bias_loader(linear.bias,
-                    state_dict[bias_key].to(torch.float32))
+        bias_loader(linear.bias, state_dict[bias_key])
     return True
 except Exception:
     logger.exception("Failed to load %s", filename)

View File

@@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
 self.bert = BertPoolingModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding)
-self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+self.classifier = nn.Linear(config.hidden_size,
+                            config.num_labels,
+                            dtype=vllm_config.model_config.head_dtype)
 pooler_config = vllm_config.model_config.pooler_config
 assert pooler_config is not None

View File

@@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
 self.new = GteNewModel(vllm_config=vllm_config,
                        prefix=prefix,
                        add_pooling_layer=True)
-self.classifier = RowParallelLinear(config.hidden_size,
-                                    config.num_labels,
-                                    input_is_parallel=False,
-                                    bias=True,
-                                    quant_config=quant_config,
-                                    prefix=maybe_prefix(
-                                        prefix, "classifier"),
-                                    return_bias=False)
+self.classifier = ReplicatedLinear(
+    config.hidden_size,
+    config.num_labels,
+    bias=True,
+    quant_config=quant_config,
+    params_dtype=vllm_config.model_config.head_dtype,
+    prefix=maybe_prefix(prefix, "classifier"),
+    return_bias=False)
 pooler_config = vllm_config.model_config.pooler_config
 assert pooler_config is not None

View File

@@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
 config = vllm_config.model_config.hf_config
 self.transformer = GPT2Model(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "gpt2"))
-self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
+self.score = nn.Linear(config.n_embd,
+                       config.num_labels,
+                       bias=False,
+                       dtype=vllm_config.model_config.head_dtype)
 pooler_config = vllm_config.model_config.pooler_config
 assert pooler_config is not None
@@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
"encode":
Pooler.for_encode(pooler_config),
"classify":
Pooler.for_classify(pooler_config, classifier=None),
Pooler.for_classify(pooler_config, classifier=self.score),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
         position_ids=positions,
         inputs_embeds=inputs_embeds,
         intermediate_tensors=intermediate_tensors)
-logits = self.score(hidden_states)
-return logits
+return hidden_states
 def _add_transformer_prefix(
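
With classifier=self.score now passed to Pooler.for_classify(...) above, forward can return the raw hidden states: the classification head is applied inside the pooler, after pooling, so it runs in head_dtype on the pooled representation only. A rough sketch of that control flow (not vLLM's actual pooler; the class below is illustrative and assumes last-token pooling):

    import torch
    import torch.nn as nn

    class SketchClassifierPooler(nn.Module):
        """Pool first, then run the (possibly fp32) classification head."""

        def __init__(self, classifier: nn.Linear):
            super().__init__()
            self.classifier = classifier

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            pooled = hidden_states[-1]                        # last-token pooling
            pooled = pooled.to(self.classifier.weight.dtype)  # match the head dtype
            return self.classifier(pooled)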

View File

@@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
     delattr(self, attr)
 config = vllm_config.model_config.hf_config
-self.v_head = RowParallelLinear(
-    config.hidden_size,
-    1,
-    bias=False,
-    input_is_parallel=False,
-    prefix=maybe_prefix(prefix, "v_head"),
-)
+self.head_dtype = vllm_config.model_config.head_dtype
+self.v_head = RowParallelLinear(config.hidden_size,
+                                1,
+                                bias=False,
+                                input_is_parallel=False,
+                                params_dtype=self.head_dtype,
+                                prefix=maybe_prefix(prefix, "v_head"),
+                                return_bias=False)
 pooler_config = vllm_config.model_config.pooler_config
 assert pooler_config is not None
@@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
 ) -> Union[torch.Tensor, IntermediateTensors]:
     hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                inputs_embeds)
-    logits, _ = self.v_head(hidden_states)
+    hidden_states = hidden_states.to(self.head_dtype)
+    logits = self.v_head(hidden_states)
     return logits

View File

@@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
     config.hidden_size,
     num_labels,
     bias=score_bias,
-    dtype=torch.float32,
+    dtype=vllm_config.model_config.head_dtype,
 )
 pooler_config = vllm_config.model_config.pooler_config

View File

@@ -5,9 +5,9 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, PretrainedConfig
+from transformers import BatchFeature
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
 class JinaVLScorer(nn.Module):
-    def __init__(self, config: PretrainedConfig):
+    def __init__(self, model_config: "ModelConfig"):
         super().__init__()
+        config = model_config.hf_config
+        head_dtype = model_config.head_dtype
         self.dense = ColumnParallelLinear(config.hidden_size,
                                           config.hidden_size,
+                                          params_dtype=head_dtype,
                                           bias=True)
         self.out_proj = RowParallelLinear(config.hidden_size,
                                           config.num_labels,
+                                          params_dtype=head_dtype,
                                           bias=True)
     def forward(self, x, **kwargs):
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
 def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     super().__init__(vllm_config=vllm_config,
                      prefix=maybe_prefix(prefix, "qwen2_vl"))
-    config = vllm_config.model_config.hf_config
     pooler_config = vllm_config.model_config.pooler_config
     assert pooler_config is not None
-    self.score = JinaVLScorer(config)
+    self.score = JinaVLScorer(vllm_config.model_config)
     self.pooler = DispatchPooler({
         "encode":
         Pooler.for_encode(pooler_config),

View File

@@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
 self.config = config
 self.model = ModernBertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "modernbert"))
-self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+self.classifier = nn.Linear(config.hidden_size,
+                            config.num_labels,
+                            dtype=vllm_config.model_config.head_dtype)
 self.pooling = ModernBertPooler(config)
 pooler_config = vllm_config.model_config.pooler_config

View File

@@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
 self.quant_config = quant_config
 self.model = Qwen2Model(vllm_config=vllm_config,
                         prefix=maybe_prefix(prefix, "model"))
+self.head_dtype = vllm_config.model_config.head_dtype
 self.score = nn.Sequential(
     ColumnParallelLinear(config.hidden_size,
                          config.hidden_size,
                          quant_config=quant_config,
+                         params_dtype=self.head_dtype,
                          return_bias=False),
     nn.ReLU(),
     RowParallelLinear(config.hidden_size,
                       config.num_labels,
+                      params_dtype=self.head_dtype,
                       quant_config=quant_config,
                       return_bias=False),
 )
@@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
 ) -> Union[torch.Tensor, IntermediateTensors]:
     hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                inputs_embeds)
+    hidden_states = hidden_states.to(self.head_dtype)
     logits = self.score(hidden_states)
     return logits

View File

@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from transformers import RobertaConfig
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
 class RobertaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
-    def __init__(self, config: RobertaConfig):
+    def __init__(self, model_config: "ModelConfig"):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+        config = model_config.hf_config
+        head_dtype = model_config.head_dtype
+        self.dense = nn.Linear(config.hidden_size,
+                               config.hidden_size,
+                               dtype=head_dtype)
+        self.out_proj = nn.Linear(config.hidden_size,
+                                  config.num_labels,
+                                  dtype=head_dtype)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # CLSPool has already been applied in `pooling`
@@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
 self.roberta = BertModel(vllm_config=vllm_config,
                          prefix=maybe_prefix(prefix, "bert"),
                          embedding_class=RobertaEmbedding)
-self.classifier = RobertaClassificationHead(config)
+self.classifier = RobertaClassificationHead(vllm_config.model_config)
 pooler_config = vllm_config.model_config.pooler_config
 assert pooler_config is not None