[Model] Support math-shepherd-mistral-7b-prm model (#9697)

Signed-off-by: Went-Liang <wenteng_liang@163.com>
Author: Went-Liang
Date:   2024-10-31 00:33:42 +08:00
Committed by: GitHub
Parent: cc98f1e079
Commit: 81f09cfd80
14 changed files with 312 additions and 62 deletions

@@ -29,7 +29,7 @@ from transformers import LlamaConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -502,6 +502,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
         prefix: str = "",
+        pooler_config: Optional[PoolerConfig] = None,
     ) -> None:
         super().__init__()
@@ -543,6 +544,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.lm_head = PPMissingLayer()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.STEP,
+            normalize=False,
+            softmax=False)

     def forward(
         self,
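The hunk above wires a step-wise pooler into `LlamaForCausalLM`. `Pooler.from_config_with_defaults` lets a user-supplied `PoolerConfig` override the per-model defaults (STEP pooling, no normalization, no softmax). A minimal sketch of that override semantics, with hypothetical field handling rather than the exact vLLM implementation:

```python
from dataclasses import dataclass
from typing import Optional


# Assumed shape of the config; the real vLLM PoolerConfig has more options.
@dataclass
class PoolerConfig:
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None
    softmax: Optional[bool] = None


def from_config_with_defaults(config: Optional[PoolerConfig], *,
                              pooling_type: str, normalize: bool,
                              softmax: bool) -> PoolerConfig:
    # Unset (None) config fields fall back to the model's defaults.
    if config is None:
        return PoolerConfig(pooling_type, normalize, softmax)
    return PoolerConfig(
        pooling_type=config.pooling_type or pooling_type,
        normalize=normalize if config.normalize is None else config.normalize,
        softmax=softmax if config.softmax is None else config.softmax,
    )
```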
@@ -565,6 +571,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                        sampling_metadata)
         return logits

+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        logits = self.compute_logits(hidden_states, None)
+        return self._pooler(logits, pooling_metadata)
+
     def sample(self, logits: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(logits, sampling_metadata)
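The new `pooler` entry point reuses `compute_logits`, so the STEP pooler operates on full-vocabulary logits rather than hidden states. For a process-reward model like math-shepherd, the intent suggested by this PR's title is to read scores at step-tag positions. A hypothetical illustration of that selection; `step_tag_id` and `returned_token_ids` are assumed names, not confirmed API:

```python
import torch


def step_pool(logits: torch.Tensor, token_ids: torch.Tensor,
              step_tag_id: int, returned_token_ids: list[int]) -> torch.Tensor:
    """Select per-step scores from (seq_len, vocab) logits.

    Keeps only the candidate token columns (e.g. the '+'/'-' labels used
    by math-shepherd) at the positions where the step-tag token appears.
    Any normalization/softmax would be applied per the PoolerConfig.
    """
    candidates = logits[:, returned_token_ids]   # (seq_len, n_candidates)
    return candidates[token_ids == step_tag_id]  # (n_steps, n_candidates)
```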
@@ -630,12 +644,17 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
     def __init__(
         self,
+        pooler_config: Optional[PoolerConfig] = None,
         **kwargs,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
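End-to-end, the change makes the PRM runnable through vLLM's pooling path. A hypothetical offline-inference sketch; the checkpoint name, the `task` flag, and the output field are assumptions based on this PR's title and vLLM's embedding API of this era, not part of the diff:

```python
from vllm import LLM

# Assumed checkpoint; 'ки' is math-shepherd's step-tag token.
llm = LLM(model="peiyi9979/math-shepherd-mistral-7b-prm", task="embedding")

prompt = "Janet has 3 apples and buys 2 more. Step 1: 3 + 2 = 5 ки"
(output, ) = llm.encode(prompt)
print(output.outputs.embedding)  # per-step scores via the STEP pooler
```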