[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 00:05:40 +08:00
committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
17 changed files with 247 additions and 345 deletions

View File

@@ -28,9 +28,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
@@ -404,6 +403,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
def __init__(
self,
*,
@@ -428,7 +429,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
)
pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
@@ -446,10 +447,3 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
inputs_embeds)
logits, _ = self.v_head(hidden_states)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)