[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -8,7 +8,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.config.pooler import TokenPoolingType
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
@@ -113,12 +113,10 @@ class StepPool(AllPool):
|
||||
return pooled_data
|
||||
|
||||
|
||||
def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
|
||||
def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
|
||||
if pooling_type == "ALL":
|
||||
return AllPool()
|
||||
if pooling_type == "STEP":
|
||||
return StepPool()
|
||||
|
||||
# TODO: Separate seq and tok pooling types so we don't need this fallback
|
||||
return AllPool()
|
||||
raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")
|
||||
|
||||
@@ -85,7 +85,7 @@ class TokenPooler(Pooler):
|
||||
|
||||
|
||||
def pooler_for_token_embed(pooler_config: PoolerConfig):
|
||||
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
|
||||
return TokenPooler(pooling=pooling, head=head)
|
||||
@@ -99,7 +99,7 @@ def pooler_for_token_classify(
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
):
|
||||
if pooling is None:
|
||||
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
|
||||
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user