[Refactor] Clean up pooler modules (#31897)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from itertools import groupby
|
||||
from typing import TypeVar
|
||||
from typing import TypeAlias, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
|
||||
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -30,6 +30,15 @@ PoolingFn = Callable[
|
||||
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
|
||||
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
|
||||
|
||||
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
"""Enumeration for different types of pooling methods."""
|
||||
|
||||
@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
|
||||
return PoolingParamsUpdate()
|
||||
|
||||
@abstractmethod
|
||||
def forward_all(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooling_cursor = pooling_metadata.pooling_cursor
|
||||
return self.forward_all(hidden_states, pooling_cursor)
|
||||
) -> PoolingMethodOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CLSPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with CLS pooling"
|
||||
)
|
||||
@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
return hidden_states[pooling_cursor.last_token_indices_gpu]
|
||||
|
||||
|
||||
@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify"}
|
||||
|
||||
def forward_all(
|
||||
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError(
|
||||
"forward_all is not implemented for AllPool. Use forward instead."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooling_cursor = pooling_metadata.pooling_cursor
|
||||
) -> TokensPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
is_finished = pooling_cursor.is_finished()
|
||||
hidden_states_lst = list(
|
||||
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
|
||||
@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
|
||||
p.hidden_states_cache.append(hs_chunk)
|
||||
|
||||
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
|
||||
output_list: PoolerOutput = []
|
||||
output_list = list[torch.Tensor | None]()
|
||||
for p, finished in zip(pooling_states, is_finished):
|
||||
if finished:
|
||||
hidden_states_cache = p.hidden_states_cache
|
||||
@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with MEAN pooling"
|
||||
)
|
||||
@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError
|
||||
@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PoolerHead(nn.Module):
|
||||
def __init__(self, activation: PoolerActivation) -> None:
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
class TokenPoolerHead(nn.Module, ABC):
|
||||
"""Applicable to pooling strategies that output one token."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
return self.activation(pooled_data)
|
||||
) -> TokenPoolerHeadOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EmbeddingPoolerHead(PoolerHead):
|
||||
class EmbeddingPoolerHead(TokenPoolerHead):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(activation=PoolerNormalize())
|
||||
super().__init__()
|
||||
|
||||
# Load ST projector if available
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.projector: nn.Module | None = (
|
||||
self.projector = (
|
||||
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
self.activation = PoolerNormalize()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokenPoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
# pooled_data shape: [batchsize, hidden_dimension]
|
||||
@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
|
||||
3. Returns structured results as `PoolerOutput`.
|
||||
"""
|
||||
|
||||
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
|
||||
def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = pooling
|
||||
@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokenPoolerHeadOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
|
||||
return scores
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
||||
class TokensPoolerHead(nn.Module, ABC):
|
||||
"""Applicable to pooling strategies that output multiple tokens."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
|
||||
) -> PoolerOutput:
|
||||
self,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> TokensPoolerHeadOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(TokensPoolerHead):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Load ST projector if available
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.projector = (
|
||||
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
self.activation = PoolerNormalize()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> TokensPoolerHeadOutput:
|
||||
# for unfinished chunked prefill
|
||||
if pooled_data is None:
|
||||
return None
|
||||
@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
||||
return pooled_data
|
||||
|
||||
|
||||
class TokenClassifierPoolerHead(nn.Module):
|
||||
class TokenClassifierPoolerHead(TokensPoolerHead):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: ClassifierFn | None,
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.classifier = classifier
|
||||
self.act_fn = ClassifierPooler.resolve_act_fn(
|
||||
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
|
||||
)
|
||||
self.logit_bias: float | None = (
|
||||
vllm_config.model_config.pooler_config.logit_bias
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_classify"}
|
||||
self.activation = ClassifierPooler.resolve_act_fn(
|
||||
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | None,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> PoolerOutput:
|
||||
) -> TokensPoolerHeadOutput:
|
||||
# for unfinished chunked prefill
|
||||
if hidden_states is None:
|
||||
if pooled_data is None:
|
||||
return None
|
||||
|
||||
hidden_states = hidden_states.to(self.head_dtype)
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
# hidden_states shape: [n_token, hidden_size]
|
||||
|
||||
if self.classifier is not None:
|
||||
scores = self.classifier(hidden_states)
|
||||
scores = self.classifier(pooled_data)
|
||||
else:
|
||||
scores = hidden_states
|
||||
scores = pooled_data
|
||||
# scores shape: [n_token, num_labels]
|
||||
|
||||
if self.logit_bias is not None:
|
||||
scores -= self.logit_bias
|
||||
|
||||
if pooling_param.use_activation:
|
||||
scores = self.act_fn(scores)
|
||||
scores = self.activation(scores)
|
||||
|
||||
# scores shape: [n_token, num_labels]
|
||||
return scores
|
||||
|
||||
|
||||
class AllPooler(Pooler):
|
||||
def __init__(self, head: nn.Module | PoolerHead) -> None:
|
||||
def __init__(self, head: TokensPoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = AllPool()
|
||||
@@ -712,17 +735,16 @@ class AllPooler(Pooler):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokensPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
|
||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
return pooled_data
|
||||
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
|
||||
|
||||
class StepPooler(Pooler):
|
||||
def __init__(self, head: nn.Module | PoolerHead) -> None:
|
||||
def __init__(self, head: TokensPoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = AllPool()
|
||||
@@ -730,14 +752,14 @@ class StepPooler(Pooler):
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> list[torch.Tensor | None]:
|
||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
|
||||
pooled_data: PoolerOutput = []
|
||||
pooled_data = list[torch.Tensor | None]()
|
||||
for data, token_id, pooling_param in zip(
|
||||
pooled_data_lst, prompt_token_ids, pooling_params
|
||||
):
|
||||
@@ -766,15 +788,14 @@ class StepPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokensPoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
|
||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
return pooled_data
|
||||
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
|
||||
|
||||
class DispatchPooler(Pooler):
|
||||
@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
poolers_by_task = self.poolers_by_task
|
||||
|
||||
outputs = list[torch.Tensor]()
|
||||
outputs = list[torch.Tensor | None]()
|
||||
offset = 0
|
||||
for task, group in groupby(pooling_metadata.tasks):
|
||||
if not (pooler := poolers_by_task.get(task)):
|
||||
|
||||
Reference in New Issue
Block a user