[Model][1/N] Support multiple poolers at model level (#21227)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -15,7 +15,8 @@ from torch import nn
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
+from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
+                                               PoolingType)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -26,7 +27,7 @@ from .utils import AutoWeightsLoader, maybe_prefix
 class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
 
     is_pooling_model = True
-    pooler: SimplePooler
+    pooler: Pooler
 
     packed_modules_mapping = {
         "qkv_proj": [
@@ -94,12 +95,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 1
         super().__init__(vllm_config=vllm_config, prefix=prefix)
+
         pooler_config = vllm_config.model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.ALL,
-            normalize=False,
-            softmax=False)
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler(
+            {"encode": Pooler.for_encode(pooler_config)}, )
 
 
 class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
@@ -107,11 +108,17 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         vllm_config.model_config.hf_config.num_labels = 2
         super().__init__(vllm_config=vllm_config, prefix=prefix)
+
         pooler_config = vllm_config.model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.STEP,
-            normalize=False,
-            softmax=True,
-            step_tag_id=151651,
-        )
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(
+                pooler_config,
+                default_pooling_type=PoolingType.STEP,
+                default_normalize=False,
+                default_softmax=True,
+                default_step_tag_id=151651,
+            )
+        })
Reference in New Issue
Block a user