[Refactor] Clean up pooler modules (#31897)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 00:07:43 +08:00
committed by GitHub
parent cc6dafaef2
commit b7036c87a1
7 changed files with 167 additions and 120 deletions

View File

@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding, SupportsQuant
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    """Return the pooling-parameter updates for ``task``.

    The request is forwarded unchanged to ``self.pooling`` and its answer is
    returned as-is.
    """
    delegate = self.pooling
    return delegate.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
def head(
    self,
    pooled_data: TokenPoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
    """Apply the dense projection and activation to the pooled data.

    Args:
        pooled_data: Either a single tensor or a list of tensors. A list is
            combined with ``torch.stack`` (new leading dimension) before the
            layers are applied.
        pooling_metadata: Accepted as part of the signature; this method does
            not read it.

    Returns:
        The result of ``self.activation(self.dense(...))`` on the (possibly
        stacked) input.
    """
    # Normalise the list form to one tensor so the layers run exactly once.
    batch = torch.stack(pooled_data) if isinstance(pooled_data, list) else pooled_data
    return self.activation(self.dense(batch))
# NOTE(review): this span appears to interleave the pre-change and post-change
# versions of `forward` from a diff whose +/- markers are not shown. It has two
# `hidden_states` parameters, two signature closers and two bodies, so it is
# not a single valid function as written. The old/new attribution in the
# comments below is an assumption -- confirm against the actual commit.
def forward(
self,
# Presumably the removed parameter: accepts a tensor or a list of tensors.
hidden_states: torch.Tensor | list[torch.Tensor],
# Presumably the added parameter: accepts a single tensor only.
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
# Presumably the removed return annotation and body: pools the hidden states,
# then applies `self._head` per element when the pooled result is a list, or
# once when it is not a list.
) -> torch.Tensor | list[torch.Tensor]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output
# Presumably the added return annotation and body: pools the hidden states and
# passes the result, together with the metadata, to `self.head`, which handles
# the list case itself.
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
class BertEncoder(nn.Module):