[Refactor] Separate sequence and token pooling types (#32026)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-10 12:53:24 +08:00
committed by GitHub
parent 52d428295d
commit 583a90e005
42 changed files with 324 additions and 204 deletions

View File

@@ -7,7 +7,7 @@ from typing import TypeAlias
import torch
import torch.nn as nn
-from vllm.config.pooler import PoolingTypeStr
+from vllm.config.pooler import SequencePoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
@@ -82,11 +82,11 @@ class MeanPool(SequencePoolingMethod):
) / prompt_lens.unsqueeze(1)
-def get_seq_pooling_method(pooling_type: PoolingTypeStr | str):
-if pooling_type == "LAST":
-return LastPool()
+def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
if pooling_type == "CLS":
return CLSPool()
+if pooling_type == "LAST":
+return LastPool()
if pooling_type == "MEAN":
return MeanPool()

View File

@@ -85,7 +85,7 @@ class SequencePooler(Pooler):
def pooler_for_embed(pooler_config: PoolerConfig):
-pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
+pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = EmbeddingPoolerHead()
return SequencePooler(pooling=pooling, head=head)
@@ -99,7 +99,7 @@ def pooler_for_classify(
act_fn: PoolerActivation | str | None = None,
):
if pooling is None:
-pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
+pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)

View File

@@ -8,7 +8,7 @@ import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
-from vllm.config.pooler import PoolingTypeStr
+from vllm.config.pooler import TokenPoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
@@ -113,12 +113,10 @@ class StepPool(AllPool):
return pooled_data
-def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
+def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
if pooling_type == "ALL":
return AllPool()
if pooling_type == "STEP":
return StepPool()
-# TODO: Separate seq and tok pooling types so we don't need this fallback
-return AllPool()
+raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")

View File

@@ -85,7 +85,7 @@ class TokenPooler(Pooler):
def pooler_for_token_embed(pooler_config: PoolerConfig):
-pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
+pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenEmbeddingPoolerHead()
return TokenPooler(pooling=pooling, head=head)
@@ -99,7 +99,7 @@ def pooler_for_token_classify(
act_fn: PoolerActivation | str | None = None,
):
if pooling is None:
-pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
+pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)