[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by: wang.yuqi <noooop@126.com>
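
This change replaces hard-coded `torch.float32` heads with a single `model_config.head_dtype` knob across the pooling/classification models. Every hunk below applies the same two-step recipe; here is a minimal sketch of it (illustrative sizes; `head_dtype` is assumed to default to fp32 for these heads while the backbone keeps its own dtype, e.g. bf16):

    import torch
    import torch.nn as nn

    head_dtype = torch.float32   # stands in for model_config.head_dtype

    # 1) create the head's parameters directly in head_dtype ...
    score = nn.Linear(1024, 2, bias=False, dtype=head_dtype)

    # 2) ... and cast activations once, at the head boundary
    hidden_states = torch.randn(4, 1024, dtype=torch.bfloat16)  # backbone output
    logits = score(hidden_states.to(head_dtype))                # fp32 logits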
@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
                 linear = nn.Linear(layer_config.get("in_features", 768),
                                    layer_config.get("out_features", 768),
                                    bias=layer_config.get("bias", True),
-                                   dtype=torch.float32)
+                                   dtype=model_config.head_dtype)
 
                 if not _load_dense_weights(linear, folder, model_config):
                     continue
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
                 layers.append(linear)
                 if act_name := layer_config.get("activation_function"):
                     layers.append(get_act_fn(act_name))
-        return nn.Sequential(*layers).to(dtype=torch.float32)
+        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
     except Exception:
         logger.exception("ST projector loading failed")
 
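For context, the `layer_config` consumed above comes from a Sentence-Transformers Dense module's `config.json`. A typical example, written as the dict the loader sees (values are illustrative; the keys match the `layer_config.get(...)` calls in the hunk):

    layer_config = {
        "in_features": 768,
        "out_features": 512,
        "bias": True,
        "activation_function": "torch.nn.modules.activation.Tanh",
    }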
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
         if weight_key in state_dict:
             weight_loader = getattr(linear.weight, "weight_loader",
                                     default_weight_loader)
-            weight_loader(linear.weight,
-                          state_dict[weight_key].to(torch.float32))
+            weight_loader(linear.weight, state_dict[weight_key])
 
             bias_key = weight_key.replace("weight", "bias")
             if linear.bias is not None and bias_key in state_dict:
                 bias_loader = getattr(linear.bias, "weight_loader",
                                       default_weight_loader)
-                bias_loader(linear.bias,
-                            state_dict[bias_key].to(torch.float32))
+                bias_loader(linear.bias, state_dict[bias_key])
         return True
     except Exception:
         logger.exception("Failed to load %s", filename)
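The explicit `.to(torch.float32)` casts can be dropped because the `nn.Linear` is now created in `head_dtype`, and the default loader copies the checkpoint tensor into the existing parameter, which converts it to the parameter's dtype. A sketch of that behavior (assumed to match vLLM's `default_weight_loader`; this is not the library's exact implementation):

    import torch

    def default_weight_loader(param: torch.nn.Parameter,
                              loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        # Tensor.copy_ converts to param's dtype, so a head created with
        # dtype=head_dtype receives the checkpoint weight in head_dtype.
        param.data.copy_(loaded_weight)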
@@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
         self.bert = BertPoolingModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "bert"),
                                      embedding_class=BertEmbedding)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.classifier = nn.Linear(config.hidden_size,
+                                    config.num_labels,
+                                    dtype=vllm_config.model_config.head_dtype)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.new = GteNewModel(vllm_config=vllm_config,
                                prefix=prefix,
                                add_pooling_layer=True)
-        self.classifier = RowParallelLinear(config.hidden_size,
-                                            config.num_labels,
-                                            input_is_parallel=False,
-                                            bias=True,
-                                            quant_config=quant_config,
-                                            prefix=maybe_prefix(
-                                                prefix, "classifier"),
-                                            return_bias=False)
+        self.classifier = ReplicatedLinear(
+            config.hidden_size,
+            config.num_labels,
+            bias=True,
+            quant_config=quant_config,
+            params_dtype=vllm_config.model_config.head_dtype,
+            prefix=maybe_prefix(prefix, "classifier"),
+            return_bias=False)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
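A note on the layer swap above (rationale inferred from the diff, not stated in the commit message): a `[hidden_size, num_labels]` classification head is tiny, so replicating it on every tensor-parallel rank with `ReplicatedLinear` avoids the communication that `RowParallelLinear(input_is_parallel=False)` performs, and `params_dtype` then sets the weight dtype uniformly on all ranks. A minimal sketch of the dtype guarantee (hypothetical sizes; assumes a single-rank setup where the layer can be constructed directly):

    import torch
    from vllm.model_executor.layers.linear import ReplicatedLinear

    clf = ReplicatedLinear(1024, 2, bias=True,
                           params_dtype=torch.float32,
                           return_bias=False)
    assert clf.weight.dtype == torch.float32  # head weights live in fp32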
@@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
         config = vllm_config.model_config.hf_config
         self.transformer = GPT2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "gpt2"))
-        self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
+        self.score = nn.Linear(config.n_embd,
+                               config.num_labels,
+                               bias=False,
+                               dtype=vllm_config.model_config.head_dtype)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
             "encode":
             Pooler.for_encode(pooler_config),
             "classify":
-            Pooler.for_classify(pooler_config, classifier=None),
+            Pooler.for_classify(pooler_config, classifier=self.score),
         })
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
                                          position_ids=positions,
                                          inputs_embeds=inputs_embeds,
                                          intermediate_tensors=intermediate_tensors)
-        logits = self.score(hidden_states)
-        return logits
+        return hidden_states
 
 
 def _add_transformer_prefix(
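Together, the two GPT2 hunks move the classification head out of `forward` and into the pooler: the model now returns raw hidden states, and the classify pooler applies `self.score` to the pooled representation. A runnable toy sketch of the resulting flow (assumed pooler semantics; names and sizes are illustrative, not vLLM's internals):

    import torch
    import torch.nn as nn

    score = nn.Linear(1024, 2, bias=False, dtype=torch.float32)     # fp32 head
    hidden_states = torch.randn(3, 7, 1024, dtype=torch.bfloat16)   # [batch, seq, hid]

    # What the classify pooler does conceptually: pool, cast, classify.
    pooled = hidden_states[:, -1]                  # e.g. last-token pooling
    logits = score(pooled.to(score.weight.dtype))  # logits in float32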
@@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
             delattr(self, attr)
 
         config = vllm_config.model_config.hf_config
-        self.v_head = RowParallelLinear(
-            config.hidden_size,
-            1,
-            bias=False,
-            input_is_parallel=False,
-            prefix=maybe_prefix(prefix, "v_head"),
-        )
+        self.head_dtype = vllm_config.model_config.head_dtype
+
+        self.v_head = RowParallelLinear(config.hidden_size,
+                                        1,
+                                        bias=False,
+                                        input_is_parallel=False,
+                                        params_dtype=self.head_dtype,
+                                        prefix=maybe_prefix(prefix, "v_head"),
+                                        return_bias=False)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
-        logits, _ = self.v_head(hidden_states)
+        hidden_states = hidden_states.to(self.head_dtype)
+        logits = self.v_head(hidden_states)
         return logits
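Two things change in this forward pass: the hidden states are cast to `head_dtype` before entering `v_head` (whose weights now live in that dtype), and the tuple unpacking disappears because vLLM's parallel linear layers return `(output, output_bias)` by default, while `return_bias=False` makes them return the output tensor alone. A small sketch of the call-convention difference (illustrative sizes; assumes a single-rank setup):

    import torch
    from vllm.model_executor.layers.linear import RowParallelLinear

    x = torch.randn(4, 4096, dtype=torch.float32)

    head = RowParallelLinear(4096, 1, bias=False, input_is_parallel=False)
    out, _ = head(x)      # default: (output, output_bias) tuple

    head = RowParallelLinear(4096, 1, bias=False, input_is_parallel=False,
                             return_bias=False)
    out = head(x)         # return_bias=False: just the output tensor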
@@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
             config.hidden_size,
             num_labels,
             bias=score_bias,
-            dtype=torch.float32,
+            dtype=vllm_config.model_config.head_dtype,
         )
 
         pooler_config = vllm_config.model_config.pooler_config
@@ -5,9 +5,9 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, PretrainedConfig
+from transformers import BatchFeature
 
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
 
 class JinaVLScorer(nn.Module):
 
-    def __init__(self, config: PretrainedConfig):
+    def __init__(self, model_config: "ModelConfig"):
         super().__init__()
+        config = model_config.hf_config
+        head_dtype = model_config.head_dtype
         self.dense = ColumnParallelLinear(config.hidden_size,
                                           config.hidden_size,
+                                          params_dtype=head_dtype,
                                           bias=True)
         self.out_proj = RowParallelLinear(config.hidden_size,
                                           config.num_labels,
+                                          params_dtype=head_dtype,
                                           bias=True)
 
     def forward(self, x, **kwargs):
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,
                          prefix=maybe_prefix(prefix, "qwen2_vl"))
-        config = vllm_config.model_config.hf_config
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        self.score = JinaVLScorer(config)
+        self.score = JinaVLScorer(vllm_config.model_config)
         self.pooler = DispatchPooler({
             "encode":
             Pooler.for_encode(pooler_config),
@@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.config = config
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.classifier = nn.Linear(config.hidden_size,
+                                    config.num_labels,
+                                    dtype=vllm_config.model_config.head_dtype)
         self.pooling = ModernBertPooler(config)
 
         pooler_config = vllm_config.model_config.pooler_config
@@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
         self.quant_config = quant_config
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
+        self.head_dtype = vllm_config.model_config.head_dtype
+
         self.score = nn.Sequential(
             ColumnParallelLinear(config.hidden_size,
                                  config.hidden_size,
                                  quant_config=quant_config,
+                                 params_dtype=self.head_dtype,
                                  return_bias=False),
             nn.ReLU(),
             RowParallelLinear(config.hidden_size,
                               config.num_labels,
+                              params_dtype=self.head_dtype,
                               quant_config=quant_config,
                               return_bias=False),
         )
@@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
+        hidden_states = hidden_states.to(self.head_dtype)
         logits = self.score(hidden_states)
         return logits
 
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from transformers import RobertaConfig
 
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
 class RobertaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 
-    def __init__(self, config: RobertaConfig):
+    def __init__(self, model_config: "ModelConfig"):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+        config = model_config.hf_config
+        head_dtype = model_config.head_dtype
+        self.dense = nn.Linear(config.hidden_size,
+                               config.hidden_size,
+                               dtype=head_dtype)
+        self.out_proj = nn.Linear(config.hidden_size,
+                                  config.num_labels,
+                                  dtype=head_dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # CLSPool has already been applied in `pooling`
@@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.roberta = BertModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "bert"),
                                  embedding_class=RobertaEmbedding)
-        self.classifier = RobertaClassificationHead(config)
+        self.classifier = RobertaClassificationHead(vllm_config.model_config)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None