[Model] Support math-shepherd-mistral-7b-prm model (#9697)
Signed-off-by: Went-Liang <wenteng_liang@163.com>
This commit is contained in:
@@ -29,7 +29,7 @@ from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -502,6 +502,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
pooler_config: Optional[PoolerConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -543,6 +544,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.STEP,
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -565,6 +571,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
logits = self.compute_logits(hidden_states, None)
|
||||
return self._pooler(logits, pooling_metadata)
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
@@ -630,12 +644,17 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pooler_config: Optional[PoolerConfig] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model = LlamaModel(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user