[Model][1/N] Support multiple poolers at model level (#21227)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-21 17:22:21 +08:00
committed by GitHub
parent 378d33c392
commit 042af0c8d3
22 changed files with 549 additions and 413 deletions

View File

@@ -15,7 +15,8 @@ from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@@ -26,7 +27,7 @@ from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
is_pooling_model = True
pooler: SimplePooler
pooler: Pooler
packed_modules_mapping = {
"qkv_proj": [
@@ -94,12 +95,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
softmax=False)
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
@@ -107,11 +108,17 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=True,
step_tag_id=151651,
)
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(
pooler_config,
default_pooling_type=PoolingType.STEP,
default_normalize=False,
default_softmax=True,
default_step_tag_id=151651,
)
})