[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
@default_pooling_type("ALL")
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
vllm_config.model_config.hf_config.num_labels = 1
|
||||
@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
|
||||
@default_pooling_type("STEP")
|
||||
@default_pooling_type(tok_pooling_type="STEP")
|
||||
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
vllm_config.model_config.hf_config.num_labels = 2
|
||||
|
||||
Reference in New Issue
Block a user