[Model] Add support for nvidia/llama-nemotron-rerank-vl-1b-v2 (#35735)
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
@@ -664,6 +664,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
|
||||
"LlamaBidirectionalModel": LlamaBidirectionalConfig,
|
||||
"LlamaNemotronVLModel": LlamaNemotronVLConfig,
|
||||
"LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
|
||||
"NomicBertModel": NomicBertModelConfig,
|
||||
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
|
||||
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
|
||||
@@ -18,6 +19,7 @@ from transformers import AutoModel, PretrainedConfig
|
||||
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
@@ -42,6 +44,7 @@ from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsCrossEncoding,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
@@ -883,3 +886,57 @@ class LlamaNemotronVLForEmbedding(LlamaNemotronVLChatModel, VllmModelForPooling)
|
||||
"""Override to use different weight mapping for SigLIP."""
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.weight_mapper)
|
||||
|
||||
|
||||
class LlamaNemotronVLForSequenceClassification(
    LlamaNemotronVLForEmbedding, SupportsCrossEncoding
):
    """LlamaNemotronVL model variant for sequence classification / reranking.

    Extends the embedding model with a linear `score` head and a
    sequence-classification pooler, and adapts weight loading to the
    reranker checkpoint layout.
    """

    # Reranker checkpoint places base model weights under `model.*`,
    # while `score.*` remains at the top level. Strip the `model.` prefix
    # first, then apply the parent class's mapping on top.
    # NOTE(review): assumes `WeightsMapper.__or__` composes mappings with
    # left-hand rules taking effect before the right-hand ones — confirm
    # against the WeightsMapper implementation.
    weight_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) | (
        LlamaNemotronVLForEmbedding.weight_mapper
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # `num_labels` comes from the text sub-config of the (multimodal)
        # HF config; the hidden size and head dtype come from vLLM's
        # model config.
        text_config = vllm_config.model_config.hf_config.get_text_config()
        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config

        # Classification head: maps pooled hidden states to per-label
        # logits. Replicated (not sharded) across tensor-parallel ranks.
        self.score = ReplicatedLinear(
            model_config.get_hidden_size(),
            text_config.num_labels,
            bias=False,
            params_dtype=model_config.head_dtype,
            quant_config=quant_config,
            return_bias=False,
            prefix=maybe_prefix(prefix, "score"),
        )

        # Pooler dispatching to the sequence-classification path, using
        # `score` as the classifier applied after pooling.
        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, then default-initialize any parameters
        of the unused inner LM head that the checkpoint does not provide.

        Returns the set of parameter names that were populated (loaded or
        initialized here).
        """
        loaded_weights = super().load_weights(weights)

        # reranker checkpoint omits the inner LM seq-cls head
        # (`language_model.score.*`). It is unused by this outer model, but
        # the default loader expects all parameters to be initialized.
        for name, param in self.named_parameters():
            if not name.startswith("language_model.score.") or name in loaded_weights:
                continue

            # Mirror torch.nn.Linear's default init for weights; zero the
            # bias; fall back to a small normal for anything else.
            if name.endswith(".weight"):
                torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            elif name.endswith(".bias"):
                torch.nn.init.zeros_(param)
            else:
                torch.nn.init.normal_(param, mean=0.0, std=0.02)

            # Mark as populated so the loader does not flag it as missing.
            loaded_weights.add(name)

        return loaded_weights
|
||||
|
||||
@@ -284,6 +284,10 @@ _CROSS_ENCODER_MODELS = {
|
||||
"llama",
|
||||
"LlamaBidirectionalForSequenceClassification",
|
||||
),
|
||||
"LlamaNemotronVLForSequenceClassification": (
|
||||
"nemotron_vl",
|
||||
"LlamaNemotronVLForSequenceClassification",
|
||||
),
|
||||
"ModernBertForSequenceClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForSequenceClassification",
|
||||
|
||||
Reference in New Issue
Block a user