[Model] Support math-shepherd-mistral-7b-prm model (#9697)
Signed-off-by: Went-Liang <wenteng_liang@163.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
@@ -64,6 +64,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
pooler_config: Optional[PoolerConfig] = None,
|
||||
) -> None:
|
||||
# TODO (@robertgshaw2): see if this can be moved out
|
||||
if (cache_config.sliding_window is not None
|
||||
@@ -93,8 +94,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
|
||||
RowParallelLinear(config.hidden_size, 1,
|
||||
quant_config=quant_config),
|
||||
)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.ALL,
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user