[Model] Support math-shepherd-mistral-7b-prm model (#9697)
Signed-off-by: Went-Liang <wenteng_liang@163.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from transformers import Gemma2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
@@ -473,13 +473,17 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pooler_config: Optional[PoolerConfig] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model = Gemma2Model(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user