[Refactor] Clean up pooler modules (#31897)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from itertools import groupby
|
||||
from typing import TypeVar
|
||||
from typing import TypeAlias, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
|
||||
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -30,6 +30,15 @@ PoolingFn = Callable[
|
||||
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
|
||||
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
|
||||
|
||||
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
"""Enumeration for different types of pooling methods."""
|
||||
|
||||
@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
|
||||
return PoolingParamsUpdate()
|
||||
|
||||
@abstractmethod
|
||||
def forward_all(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooling_cursor = pooling_metadata.pooling_cursor
|
||||
return self.forward_all(hidden_states, pooling_cursor)
|
||||
) -> PoolingMethodOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CLSPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with CLS pooling"
|
||||
)
|
||||
@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
return hidden_states[pooling_cursor.last_token_indices_gpu]
|
||||
|
||||
|
||||
@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify"}
|
||||
|
||||
def forward_all(
|
||||
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError(
|
||||
"forward_all is not implemented for AllPool. Use forward instead."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooling_cursor = pooling_metadata.pooling_cursor
|
||||
) -> TokensPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
is_finished = pooling_cursor.is_finished()
|
||||
hidden_states_lst = list(
|
||||
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
|
||||
@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
|
||||
p.hidden_states_cache.append(hs_chunk)
|
||||
|
||||
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
|
||||
output_list: PoolerOutput = []
|
||||
output_list = list[torch.Tensor | None]()
|
||||
for p, finished in zip(pooling_states, is_finished):
|
||||
if finished:
|
||||
hidden_states_cache = p.hidden_states_cache
|
||||
@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> PoolerOutput:
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with MEAN pooling"
|
||||
)
|
||||
@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError
|
||||
@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PoolerHead(nn.Module):
|
||||
def __init__(self, activation: PoolerActivation) -> None:
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
class TokenPoolerHead(nn.Module, ABC):
|
||||
"""Applicable to pooling strategies that output one token."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
return self.activation(pooled_data)
|
||||
) -> TokenPoolerHeadOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EmbeddingPoolerHead(PoolerHead):
|
||||
class EmbeddingPoolerHead(TokenPoolerHead):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(activation=PoolerNormalize())
|
||||
super().__init__()
|
||||
|
||||
# Load ST projector if available
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.projector: nn.Module | None = (
|
||||
self.projector = (
|
||||
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
self.activation = PoolerNormalize()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokenPoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
# pooled_data shape: [batchsize, hidden_dimension]
|
||||
@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
|
||||
3. Returns structured results as `PoolerOutput`.
|
||||
"""
|
||||
|
||||
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
|
||||
def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = pooling
|
||||
@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokenPoolerHeadOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
|
||||
return scores
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
||||
class TokensPoolerHead(nn.Module, ABC):
|
||||
"""Applicable to pooling strategies that output multiple tokens."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
|
||||
) -> PoolerOutput:
|
||||
self,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> TokensPoolerHeadOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(TokensPoolerHead):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Load ST projector if available
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.projector = (
|
||||
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
self.activation = PoolerNormalize()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> TokensPoolerHeadOutput:
|
||||
# for unfinished chunked prefill
|
||||
if pooled_data is None:
|
||||
return None
|
||||
@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
||||
return pooled_data
|
||||
|
||||
|
||||
class TokenClassifierPoolerHead(nn.Module):
|
||||
class TokenClassifierPoolerHead(TokensPoolerHead):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: ClassifierFn | None,
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.classifier = classifier
|
||||
self.act_fn = ClassifierPooler.resolve_act_fn(
|
||||
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
|
||||
)
|
||||
self.logit_bias: float | None = (
|
||||
vllm_config.model_config.pooler_config.logit_bias
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_classify"}
|
||||
self.activation = ClassifierPooler.resolve_act_fn(
|
||||
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | None,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> PoolerOutput:
|
||||
) -> TokensPoolerHeadOutput:
|
||||
# for unfinished chunked prefill
|
||||
if hidden_states is None:
|
||||
if pooled_data is None:
|
||||
return None
|
||||
|
||||
hidden_states = hidden_states.to(self.head_dtype)
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
# hidden_states shape: [n_token, hidden_size]
|
||||
|
||||
if self.classifier is not None:
|
||||
scores = self.classifier(hidden_states)
|
||||
scores = self.classifier(pooled_data)
|
||||
else:
|
||||
scores = hidden_states
|
||||
scores = pooled_data
|
||||
# scores shape: [n_token, num_labels]
|
||||
|
||||
if self.logit_bias is not None:
|
||||
scores -= self.logit_bias
|
||||
|
||||
if pooling_param.use_activation:
|
||||
scores = self.act_fn(scores)
|
||||
scores = self.activation(scores)
|
||||
|
||||
# scores shape: [n_token, num_labels]
|
||||
return scores
|
||||
|
||||
|
||||
class AllPooler(Pooler):
|
||||
def __init__(self, head: nn.Module | PoolerHead) -> None:
|
||||
def __init__(self, head: TokensPoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = AllPool()
|
||||
@@ -712,17 +735,16 @@ class AllPooler(Pooler):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokensPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
|
||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
return pooled_data
|
||||
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
|
||||
|
||||
class StepPooler(Pooler):
|
||||
def __init__(self, head: nn.Module | PoolerHead) -> None:
|
||||
def __init__(self, head: TokensPoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = AllPool()
|
||||
@@ -730,14 +752,14 @@ class StepPooler(Pooler):
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> list[torch.Tensor | None]:
|
||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
|
||||
pooled_data: PoolerOutput = []
|
||||
pooled_data = list[torch.Tensor | None]()
|
||||
for data, token_id, pooling_param in zip(
|
||||
pooled_data_lst, prompt_token_ids, pooling_params
|
||||
):
|
||||
@@ -766,15 +788,14 @@ class StepPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokensPoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
|
||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
return pooled_data
|
||||
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
|
||||
|
||||
class DispatchPooler(Pooler):
|
||||
@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
poolers_by_task = self.poolers_by_task
|
||||
|
||||
outputs = list[torch.Tensor]()
|
||||
outputs = list[torch.Tensor | None]()
|
||||
offset = 0
|
||||
for task, group in groupby(pooling_metadata.tasks):
|
||||
if not (pooler := poolers_by_task.get(task)):
|
||||
|
||||
@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
PoolingType,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.outputs import TokenPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def _head(self, pooled_output: torch.Tensor):
|
||||
pooled_output = self.dense(pooled_output)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
def head(
|
||||
self,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
|
||||
pooled_data = self.dense(pooled_data)
|
||||
pooled_data = self.activation(pooled_data)
|
||||
return pooled_data
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
||||
|
||||
if isinstance(pooled_output, list):
|
||||
pooled_output = [self._head(output) for output in pooled_output]
|
||||
else:
|
||||
pooled_output = self._head(pooled_output)
|
||||
|
||||
return pooled_output
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
@@ -4,21 +4,22 @@ from collections.abc import Set
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
PoolerHead,
|
||||
PoolerNormalize,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.outputs import TokenPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces_base import default_pooling_type
|
||||
@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GritLMMeanPool(nn.Module):
|
||||
class GritLMMeanPool(PoolingMethod):
|
||||
"""As `MeanPool`, but only includes non-instruction tokens."""
|
||||
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
|
||||
return instruction_len
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"encode", "embed"}
|
||||
return {"embed"}
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return PoolingParamsUpdate(requires_token_ids=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
) -> TokenPoolingMethodOutput:
|
||||
prompt_lens = pooling_metadata.prompt_lens
|
||||
instr_lens = torch.tensor(
|
||||
[
|
||||
@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
|
||||
super().__init__()
|
||||
|
||||
self.pooling = GritLMMeanPool(model_config)
|
||||
self.head = PoolerHead(PoolerNormalize())
|
||||
self.activation = PoolerNormalize()
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return self.pooling.get_supported_tasks()
|
||||
@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def head(
|
||||
self,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerHeadOutput:
|
||||
return self.activation(pooled_data)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
PoolingType,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.outputs import TokenPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def _head(self, pooled_output: torch.Tensor):
|
||||
pooled_output = pooled_output.to(self.dense.weight.dtype)
|
||||
return self.norm(self.act(self.dense(pooled_output)))
|
||||
def head(
|
||||
self,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
|
||||
pooled_data = pooled_data.to(self.dense.weight.dtype)
|
||||
return self.norm(self.act(self.dense(pooled_data)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
||||
|
||||
if isinstance(pooled_output, list):
|
||||
pooled_output = [self._head(output) for output in pooled_output]
|
||||
else:
|
||||
pooled_output = self._head(pooled_output)
|
||||
|
||||
return pooled_output
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple):
|
||||
|
||||
# [num_reqs, <dynamic>]
|
||||
# The shape of each element depends on the pooler used
|
||||
PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None
|
||||
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
PoolerOutput: TypeAlias = TokenPoolerOutput | TokensPoolerOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -90,6 +90,12 @@ class PoolingMetadata:
|
||||
|
||||
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
|
||||
|
||||
def get_pooling_cursor(self) -> PoolingCursor:
|
||||
pooling_cursor = self.pooling_cursor
|
||||
assert pooling_cursor is not None, "Should call `build_pooling_cursor` first"
|
||||
|
||||
return pooling_cursor
|
||||
|
||||
def build_pooling_cursor(
|
||||
self,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
|
||||
@@ -4680,7 +4680,7 @@ class GPUModelRunner(
|
||||
for task in supported_pooling_tasks:
|
||||
# Run a full batch with each task to ensure none of them OOMs
|
||||
output = self._dummy_pooler_run_task(hidden_states, task)
|
||||
output_size[task] = sum(o.nbytes for o in output)
|
||||
output_size[task] = sum(o.nbytes for o in output if o is not None)
|
||||
del output # Allow GC
|
||||
|
||||
max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
||||
|
||||
Reference in New Issue
Block a user