2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-06-03 11:20:17 -07:00
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
2025-02-02 14:58:18 -05:00
|
|
|
|
2024-09-29 12:19:39 +08:00
|
|
|
# Adapted from
|
|
|
|
|
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
|
|
|
|
|
# Copyright 2024 The Qwen team.
|
|
|
|
|
# Copyright 2023 The vLLM team.
|
|
|
|
|
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
|
2025-10-05 15:06:22 +01:00
|
|
|
|
2025-05-15 06:06:50 +01:00
|
|
|
from collections.abc import Iterable
|
2024-09-29 12:19:39 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch import nn
|
|
|
|
|
|
2024-11-08 22:17:28 -08:00
|
|
|
from vllm.config import VllmConfig
|
2024-09-29 12:19:39 +08:00
|
|
|
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
2026-01-09 19:02:14 +08:00
|
|
|
from vllm.model_executor.layers.pooler import Pooler
|
|
|
|
|
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
2025-07-18 00:05:40 +08:00
|
|
|
from vllm.sequence import IntermediateTensors
|
2024-09-29 12:19:39 +08:00
|
|
|
|
2025-08-27 21:24:09 +08:00
|
|
|
from .interfaces import SupportsLoRA, SupportsPP
|
|
|
|
|
from .interfaces_base import default_pooling_type
|
2024-10-06 16:35:27 +08:00
|
|
|
from .qwen2 import Qwen2Model
|
2024-11-10 22:41:46 -08:00
|
|
|
from .utils import AutoWeightsLoader, maybe_prefix
|
2024-09-29 12:19:39 +08:00
|
|
|
|
|
|
|
|
|
2025-06-23 17:31:06 +08:00
|
|
|
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Shared backbone and scoring head for the Qwen2 reward models.

    Runs a ``Qwen2Model`` decoder and projects its hidden states through a
    two-layer head (``hidden_size -> hidden_size -> ReLU -> num_labels``).
    The head width is taken from ``hf_config.num_labels``. Subclasses set
    that value before calling ``__init__`` and assign ``self.pooler``
    afterwards.
    """

    is_pooling_model = True
    # Assigned by subclasses after ``super().__init__`` returns.
    pooler: Pooler

    # Fused projection name -> the separate checkpoint projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Qwen2 backbone and the scoring head.

        Args:
            vllm_config: Engine configuration. ``model_config.hf_config``
                supplies ``hidden_size`` and ``num_labels``;
                ``model_config.head_dtype`` sets the head's parameter dtype.
            prefix: Parameter-name prefix for this module; the backbone is
                registered under ``<prefix>.model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # The head may run in a different dtype than the backbone; forward()
        # casts hidden states to this dtype before scoring.
        self.head_dtype = vllm_config.model_config.head_dtype

        self.score = nn.Sequential(
            ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                quant_config=quant_config,
                params_dtype=self.head_dtype,
                return_bias=False,
            ),
            nn.ReLU(),
            RowParallelLinear(
                config.hidden_size,
                config.num_labels,
                params_dtype=self.head_dtype,
                quant_config=quant_config,
                return_bias=False,
            ),
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and score its hidden states.

        Args:
            input_ids: Token ids, or ``None`` when ``inputs_embeds`` is given.
            positions: Token positions, passed straight to the backbone.
            intermediate_tensors: Activations received from the previous
                pipeline stage, if any.
            inputs_embeds: Precomputed input embeddings, if any.

        Returns:
            The head output (last dimension ``num_labels``, in ``head_dtype``)
            when the backbone returns a tensor. When the backbone returns
            ``IntermediateTensors`` (a non-final pipeline stage), that object
            is passed through unchanged.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        # Non-final pipeline stages produce IntermediateTensors, not a tensor.
        # These cannot be cast or scored; hand them to the next stage as-is.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        hidden_states = hidden_states.to(self.head_dtype)
        return self.score(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint weights into this module.

        Checkpoint entries under ``lm_head.`` are ignored, since this model
        defines no ``lm_head``. Returns the set of names reported by
        ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."])
        return loader.load_weights(weights)
|
2025-01-20 14:59:46 +08:00
|
|
|
|
|
|
|
|
|
2026-01-10 12:53:24 +08:00
|
|
|
@default_pooling_type(tok_pooling_type="ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    """Qwen2 reward model with a single-label scoring head.

    The default token pooling type is ``"ALL"``. Per-token outputs are pooled
    by the pooler returned from ``pooler_for_token_classify``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The base __init__ sizes the score head from hf_config.num_labels,
        # so that value has to be pinned before delegating.
        hf_cfg = vllm_config.model_config.hf_config
        hf_cfg.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooling_cfg = vllm_config.model_config.pooler_config
        assert pooling_cfg is not None
        self.pooler = pooler_for_token_classify(pooling_cfg)
|
2025-01-20 14:59:46 +08:00
|
|
|
|
|
|
|
|
|
2026-01-10 12:53:24 +08:00
|
|
|
@default_pooling_type(tok_pooling_type="STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
    """Qwen2 process reward model with a two-label scoring head.

    The default token pooling type is ``"STEP"``. Per-token outputs are
    pooled by the pooler returned from ``pooler_for_token_classify``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # num_labels controls the width of the base class's score head and
        # must be set before the base __init__ reads it.
        hf_cfg = vllm_config.model_config.hf_config
        hf_cfg.num_labels = 2
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooling_cfg = vllm_config.model_config.pooler_config
        assert pooling_cfg is not None
        self.pooler = pooler_for_token_classify(pooling_cfg)
|