[Model] Reorganize pooling layers (#31973)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-09 19:02:14 +08:00
Committed by: GitHub
Parent: 020732800c
Commit: c8ed39b9dd
34 changed files with 1290 additions and 1143 deletions


@@ -18,19 +18,25 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.pooler import (
     ClassifierPooler,
     DispatchPooler,
     Pooler,
-    PoolingMethod,
-    PoolingParamsUpdate,
-    TokenPoolerHeadOutput,
-    TokenPoolingMethodOutput,
 )
+from vllm.model_executor.layers.pooler.seqwise import (
+    CLSPool,
+    SequencePooler,
+    SequencePoolerHeadOutput,
+    SequencePoolerOutput,
+    SequencePoolingMethodOutput,
+)
+from vllm.model_executor.layers.pooler.tokwise import (
+    pooler_for_token_classify,
+    pooler_for_token_embed,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
-from vllm.v1.outputs import TokenPoolerOutput
 from vllm.v1.pool.metadata import PoolingMetadata
 from .interfaces import SupportsCrossEncoding, SupportsQuant
@@ -85,25 +91,21 @@ class BertEmbedding(nn.Module):
         return embeddings


-class BertPooler(Pooler):
+class BertPooler(SequencePooler):
     def __init__(self, config: BertConfig):
-        super().__init__()
+        super().__init__(
+            pooling=CLSPool(),
+            head=self.head,
+        )

-        self.pooling = PoolingMethod.from_pooling_type("CLS")
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()

-    def get_supported_tasks(self) -> Set[PoolingTask]:
-        return self.pooling.get_supported_tasks()
-
-    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
-        return self.pooling.get_pooling_updates(task)
-
     def head(
         self,
-        pooled_data: TokenPoolingMethodOutput,
+        pooled_data: SequencePoolingMethodOutput,
         pooling_metadata: PoolingMetadata,
-    ) -> TokenPoolerHeadOutput:
+    ) -> SequencePoolerHeadOutput:
         if isinstance(pooled_data, list):
             pooled_data = torch.stack(pooled_data)
@@ -111,15 +113,6 @@ class BertPooler(Pooler):
         pooled_data = self.activation(pooled_data)
         return pooled_data

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> TokenPoolerOutput:
-        pooled_data = self.pooling(hidden_states, pooling_metadata)
-        pooled_data = self.head(pooled_data, pooling_metadata)
-        return pooled_data


 class BertEncoder(nn.Module):
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -524,12 +517,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         )

     def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
-        return DispatchPooler(
-            {
-                "token_embed": Pooler.for_token_embed(pooler_config),
-                "embed": Pooler.for_embed(pooler_config),
-            }
-        )
+        return DispatchPooler.for_embedding(pooler_config)


 # Here we encode the token type ids together with the input ids.
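
`DispatchPooler.for_embedding` replaces the hand-built dispatch table. Judging only from the removed lines, it is presumably a convenience wrapper along these lines; the body below uses the pre-refactor helper names from the removed code and is an assumption, not the actual helper:

```python
# Assumed expansion of DispatchPooler.for_embedding, for illustration only.
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler


def for_embedding_sketch(pooler_config):
    return DispatchPooler(
        {
            "token_embed": Pooler.for_token_embed(pooler_config),  # per-token embeddings
            "embed": Pooler.for_embed(pooler_config),  # one embedding per prompt
        }
    )
```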
@@ -620,6 +608,7 @@ class SPLADESparsePooler(Pooler):
         remove_cls_sep: bool = True,
     ):
         super().__init__()
+        assert pooling in ("max", "sum")
         self.mlm_head = mlm_head
         self.cls_token_id = cls_token_id
@@ -637,10 +626,8 @@ class SPLADESparsePooler(Pooler):
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> torch.Tensor:
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
-        lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
+    ) -> SequencePoolerOutput:
+        lens_tensor = pooling_metadata.prompt_lens
         lens: list[int] = lens_tensor.tolist()
         B: int = len(lens)
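
The `pooling` mode validated earlier (`assert pooling in ("max", "sum")`) selects how SPLADE reduces the MLM-head logits over the sequence into one sparse, vocabulary-sized embedding per prompt. A rough sketch of that reduction, assuming the standard SPLADE `log(1 + ReLU(.))` transform; the transform itself is not visible in this diff:

```python
import torch


def splade_pool_sketch(token_logits: torch.Tensor, pooling: str = "max") -> torch.Tensor:
    # token_logits: [seq_len, vocab_size], the MLM-head output for one prompt.
    scores = torch.log1p(torch.relu(token_logits))
    if pooling == "max":
        return scores.max(dim=0).values  # [vocab_size] sparse embedding
    return scores.sum(dim=0)
```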
@@ -729,7 +716,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
         return DispatchPooler(
             {
-                "token_embed": Pooler.for_token_embed(pooler_config),
+                "token_embed": pooler_for_token_embed(pooler_config),
                 "embed": SPLADESparsePooler(
                     mlm_head=self.mlm_head,
                     cls_token_id=cls_id,
@@ -824,20 +811,10 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        self.pooler = DispatchPooler(
-            {
-                "token_classify": Pooler.for_token_classify(
-                    pooler_config, classifier=self.classifier
-                ),
-                "classify": ClassifierPooler(
-                    pooling=self.bert.pooler,
-                    classifier=self.classifier,
-                    act_fn="classify",
-                ),
-                "score": ClassifierPooler(
-                    pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
-                ),
-            }
+        self.pooler = DispatchPooler.for_seq_cls(
+            pooler_config,
+            pooling=self.bert.pooler,
+            classifier=self.classifier,
         )

     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
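
`DispatchPooler.for_seq_cls` collapses the three-entry dispatch table removed above. A hedged sketch of what it presumably builds, written with the pre-refactor names from the removed lines and with the shared arguments factored out; the actual helper may differ:

```python
# Assumed expansion of DispatchPooler.for_seq_cls, for illustration only.
from vllm.model_executor.layers.pooler import ClassifierPooler, DispatchPooler, Pooler


def for_seq_cls_sketch(pooler_config, *, pooling, classifier):
    return DispatchPooler(
        {
            "token_classify": Pooler.for_token_classify(
                pooler_config, classifier=classifier
            ),
            "classify": ClassifierPooler(
                pooling=pooling, classifier=classifier, act_fn="classify"
            ),
            "score": ClassifierPooler(
                pooling=pooling, classifier=classifier, act_fn="score"
            ),
        }
    )
```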
@@ -891,13 +868,7 @@ class BertForTokenClassification(nn.Module):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        self.pooler = DispatchPooler(
-            {
-                "token_classify": Pooler.for_token_classify(
-                    pooler_config=pooler_config
-                ),
-            }
-        )
+        self.pooler = pooler_for_token_classify(pooler_config)

     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.bert.embed_input_ids(input_ids)
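
The last hunk shows the single-task case: `BertForTokenClassification` only supports token classification, so the one-entry `DispatchPooler` wrapper is dropped and the `tokwise` helper is assigned directly. Side by side, using only the names visible in this hunk:

```python
# Before: a dispatcher wrapping a single "token_classify" pooler (removed above).
pooler_old = DispatchPooler(
    {"token_classify": Pooler.for_token_classify(pooler_config=pooler_config)}
)

# After: the tokwise helper is used directly; whether it still dispatches on the
# pooling task internally is not shown in this diff.
pooler_new = pooler_for_token_classify(pooler_config)
```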