[Refactor] Clean up pooler modules (#31897)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 00:07:43 +08:00
committed by GitHub
parent cc6dafaef2
commit b7036c87a1
7 changed files with 167 additions and 120 deletions

View File

@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding, SupportsQuant
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
pooled_data = self.dense(pooled_data)
pooled_data = self.activation(pooled_data)
return pooled_data
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
class BertEncoder(nn.Module):

View File

@@ -4,21 +4,22 @@ from collections.abc import Set
import numpy as np
import torch
import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
DispatchPooler,
Pooler,
PoolerHead,
PoolerNormalize,
PoolingMethod,
PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.outputs import PoolerOutput
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type
@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
logger = init_logger(__name__)
class GritLMMeanPool(nn.Module):
class GritLMMeanPool(PoolingMethod):
"""As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig):
@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
return instruction_len
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed"}
return {"embed"}
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True)
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
) -> TokenPoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor(
[
@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
super().__init__()
self.pooling = GritLMMeanPool(model_config)
self.head = PoolerHead(PoolerNormalize())
self.activation = PoolerNormalize()
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
return self.activation(pooled_data)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data

View File

@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
pooled_output = pooled_output.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_output)))
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
pooled_data = pooled_data.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_data)))
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("CLS")