diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d1942689d..f2f518353 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import TypeVar +from typing import TypeAlias, TypeVar import torch import torch.nn as nn @@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector from vllm.pooling_params import PoolingParams from vllm.tasks import PoolingTask from vllm.utils.import_utils import resolve_obj_by_qualname -from vllm.v1.outputs import PoolerOutput -from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata +from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput +from vllm.v1.pool.metadata import PoolingMetadata logger = init_logger(__name__) @@ -30,6 +30,15 @@ PoolingFn = Callable[ ClassifierFn = Callable[[torch.Tensor], torch.Tensor] +TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor] +TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None] +TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None +PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput + +TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor] +TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None + + class PoolingType(IntEnum): """Enumeration for different types of pooling methods.""" @@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC): return PoolingParamsUpdate() @abstractmethod - def forward_all( - self, - hidden_states: torch.Tensor, - pooling_cursor: PoolingCursor, - ) -> PoolerOutput: - raise NotImplementedError - def forward( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - pooling_cursor = pooling_metadata.pooling_cursor - return self.forward_all(hidden_states, pooling_cursor) + ) -> PoolingMethodOutput: + raise NotImplementedError class CLSPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"token_embed", "token_classify", "embed", "classify", "score"} - def forward_all( + def forward( self, hidden_states: torch.Tensor, - pooling_cursor: PoolingCursor, - ) -> PoolerOutput: + pooling_metadata: PoolingMetadata, + ) -> TokenPoolingMethodOutput: + pooling_cursor = pooling_metadata.get_pooling_cursor() assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with CLS pooling" ) @@ -159,11 +161,12 @@ class LastPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"token_embed", "token_classify", "embed", "classify", "score"} - def forward_all( + def forward( self, hidden_states: torch.Tensor, - pooling_cursor: PoolingCursor, - ) -> PoolerOutput: + pooling_metadata: PoolingMetadata, + ) -> TokenPoolingMethodOutput: + pooling_cursor = pooling_metadata.get_pooling_cursor() return hidden_states[pooling_cursor.last_token_indices_gpu] @@ -179,19 +182,12 @@ class AllPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"token_embed", "token_classify"} - def forward_all( - self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor - ) -> PoolerOutput: - raise NotImplementedError( - "forward_all is not implemented for AllPool. Use forward instead." - ) - def forward( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - pooling_cursor = pooling_metadata.pooling_cursor + ) -> TokensPoolingMethodOutput: + pooling_cursor = pooling_metadata.get_pooling_cursor() is_finished = pooling_cursor.is_finished() hidden_states_lst = list( hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) @@ -209,7 +205,7 @@ class AllPool(PoolingMethod): p.hidden_states_cache.append(hs_chunk) # 2. Once prefill is finished, send hidden_states_cache to PoolerHead - output_list: PoolerOutput = [] + output_list = list[torch.Tensor | None]() for p, finished in zip(pooling_states, is_finished): if finished: hidden_states_cache = p.hidden_states_cache @@ -228,11 +224,12 @@ class MeanPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"token_embed", "token_classify", "embed", "classify", "score"} - def forward_all( + def forward( self, hidden_states: torch.Tensor, - pooling_cursor: PoolingCursor, - ) -> PoolerOutput: + pooling_metadata: PoolingMetadata, + ) -> TokenPoolingMethodOutput: + pooling_cursor = pooling_metadata.get_pooling_cursor() assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with MEAN pooling" ) @@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC): @abstractmethod def forward( self, - hidden_states: list[torch.Tensor] | torch.Tensor, + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> PoolerOutput: raise NotImplementedError @@ -422,41 +419,42 @@ class DummyPooler(Pooler): def forward( self, - hidden_states: list[torch.Tensor] | torch.Tensor, + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> PoolerOutput: return hidden_states -class PoolerHead(nn.Module): - def __init__(self, activation: PoolerActivation) -> None: - super().__init__() - self.activation = activation +class TokenPoolerHead(nn.Module, ABC): + """Applicable to pooling strategies that output one token.""" + @abstractmethod def forward( self, - pooled_data: list[torch.Tensor] | torch.Tensor, + pooled_data: TokenPoolingMethodOutput, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - return self.activation(pooled_data) + ) -> TokenPoolerHeadOutput: + raise NotImplementedError -class EmbeddingPoolerHead(PoolerHead): +class EmbeddingPoolerHead(TokenPoolerHead): def __init__(self) -> None: - super().__init__(activation=PoolerNormalize()) + super().__init__() # Load ST projector if available vllm_config = get_current_vllm_config() - self.projector: nn.Module | None = ( + self.projector = ( _load_st_projector(vllm_config.model_config) if vllm_config else None ) self.head_dtype = vllm_config.model_config.head_dtype + self.activation = PoolerNormalize() + def forward( self, - pooled_data: list[torch.Tensor] | torch.Tensor, + pooled_data: TokenPoolingMethodOutput, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> TokenPoolerHeadOutput: if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_dimension] @@ -509,7 +507,7 @@ class SimplePooler(Pooler): 3. Returns structured results as `PoolerOutput`. """ - def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: + def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None: super().__init__() self.pooling = pooling @@ -523,9 +521,9 @@ class SimplePooler(Pooler): def forward( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> TokenPoolerHeadOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.head(pooled_data, pooling_metadata) return pooled_data @@ -591,9 +589,9 @@ class ClassifierPooler(Pooler): def forward( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> TokenPoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) @@ -622,10 +620,36 @@ class ClassifierPooler(Pooler): return scores -class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): +class TokensPoolerHead(nn.Module, ABC): + """Applicable to pooling strategies that output multiple tokens.""" + + @abstractmethod def forward( - self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams - ) -> PoolerOutput: + self, + pooled_data: TokensPoolingMethodOutputItem, + pooling_param: PoolingParams, + ) -> TokensPoolerHeadOutput: + raise NotImplementedError + + +class TokenEmbeddingPoolerHead(TokensPoolerHead): + def __init__(self) -> None: + super().__init__() + + # Load ST projector if available + vllm_config = get_current_vllm_config() + self.projector = ( + _load_st_projector(vllm_config.model_config) if vllm_config else None + ) + self.head_dtype = vllm_config.model_config.head_dtype + + self.activation = PoolerNormalize() + + def forward( + self, + pooled_data: TokensPoolingMethodOutputItem, + pooling_param: PoolingParams, + ) -> TokensPoolerHeadOutput: # for unfinished chunked prefill if pooled_data is None: return None @@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): return pooled_data -class TokenClassifierPoolerHead(nn.Module): +class TokenClassifierPoolerHead(TokensPoolerHead): def __init__( self, classifier: ClassifierFn | None, act_fn: PoolerActivation | str | None = None, ) -> None: super().__init__() + vllm_config = get_current_vllm_config() self.classifier = classifier - self.act_fn = ClassifierPooler.resolve_act_fn( - vllm_config.model_config, static_num_labels=False, act_fn=act_fn - ) self.logit_bias: float | None = ( vllm_config.model_config.pooler_config.logit_bias ) self.head_dtype = vllm_config.model_config.head_dtype - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"token_classify"} + self.activation = ClassifierPooler.resolve_act_fn( + vllm_config.model_config, static_num_labels=False, act_fn=act_fn + ) def forward( self, - hidden_states: torch.Tensor | None, + pooled_data: TokensPoolingMethodOutputItem, pooling_param: PoolingParams, - ) -> PoolerOutput: + ) -> TokensPoolerHeadOutput: # for unfinished chunked prefill - if hidden_states is None: + if pooled_data is None: return None - hidden_states = hidden_states.to(self.head_dtype) + pooled_data = pooled_data.to(self.head_dtype) # hidden_states shape: [n_token, hidden_size] if self.classifier is not None: - scores = self.classifier(hidden_states) + scores = self.classifier(pooled_data) else: - scores = hidden_states + scores = pooled_data # scores shape: [n_token, num_labels] if self.logit_bias is not None: scores -= self.logit_bias if pooling_param.use_activation: - scores = self.act_fn(scores) + scores = self.activation(scores) # scores shape: [n_token, num_labels] return scores class AllPooler(Pooler): - def __init__(self, head: nn.Module | PoolerHead) -> None: + def __init__(self, head: TokensPoolerHead) -> None: super().__init__() self.pooling = AllPool() @@ -712,17 +735,16 @@ class AllPooler(Pooler): self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> TokensPoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) - pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] - return pooled_data + return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] class StepPooler(Pooler): - def __init__(self, head: nn.Module | PoolerHead) -> None: + def __init__(self, head: TokensPoolerHead) -> None: super().__init__() self.pooling = AllPool() @@ -730,14 +752,14 @@ class StepPooler(Pooler): def extract_states( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> list[torch.Tensor | None]: pooled_data_lst = self.pooling(hidden_states, pooling_metadata) prompt_token_ids = pooling_metadata.get_prompt_token_ids() pooling_params = pooling_metadata.pooling_params - pooled_data: PoolerOutput = [] + pooled_data = list[torch.Tensor | None]() for data, token_id, pooling_param in zip( pooled_data_lst, prompt_token_ids, pooling_params ): @@ -766,15 +788,14 @@ class StepPooler(Pooler): def forward( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> TokensPoolerOutput: pooled_data = self.extract_states(hidden_states, pooling_metadata) pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) - pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] - return pooled_data + return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] class DispatchPooler(Pooler): @@ -800,12 +821,12 @@ class DispatchPooler(Pooler): def forward( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> PoolerOutput: poolers_by_task = self.poolers_by_task - outputs = list[torch.Tensor]() + outputs = list[torch.Tensor | None]() offset = 0 for task, group in groupby(pooling_metadata.tasks): if not (pooler := poolers_by_task.get(task)): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index ee429bf45..8fee4bfe4 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import ( PoolingMethod, PoolingParamsUpdate, PoolingType, + TokenPoolerHeadOutput, + TokenPoolingMethodOutput, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.outputs import TokenPoolerOutput from vllm.v1.pool.metadata import PoolingMetadata from .interfaces import SupportsCrossEncoding, SupportsQuant @@ -97,24 +100,26 @@ class BertPooler(Pooler): def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return self.pooling.get_pooling_updates(task) - def _head(self, pooled_output: torch.Tensor): - pooled_output = self.dense(pooled_output) - pooled_output = self.activation(pooled_output) - return pooled_output + def head( + self, + pooled_data: TokenPoolingMethodOutput, + pooling_metadata: PoolingMetadata, + ) -> TokenPoolerHeadOutput: + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + + pooled_data = self.dense(pooled_data) + pooled_data = self.activation(pooled_data) + return pooled_data def forward( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> torch.Tensor | list[torch.Tensor]: - pooled_output = self.pooling(hidden_states, pooling_metadata) - - if isinstance(pooled_output, list): - pooled_output = [self._head(output) for output in pooled_output] - else: - pooled_output = self._head(pooled_output) - - return pooled_output + ) -> TokenPoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return pooled_data class BertEncoder(nn.Module): diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 2aba626a7..5bd731e6e 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -4,21 +4,22 @@ from collections.abc import Set import numpy as np import torch -import torch.nn as nn from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.pooler import ( DispatchPooler, Pooler, - PoolerHead, PoolerNormalize, + PoolingMethod, PoolingParamsUpdate, + TokenPoolerHeadOutput, + TokenPoolingMethodOutput, ) from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.tasks import PoolingTask from vllm.tokenizers import cached_tokenizer_from_config -from vllm.v1.outputs import PoolerOutput +from vllm.v1.outputs import TokenPoolerOutput from vllm.v1.pool.metadata import PoolingMetadata from .interfaces_base import default_pooling_type @@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type logger = init_logger(__name__) -class GritLMMeanPool(nn.Module): +class GritLMMeanPool(PoolingMethod): """As `MeanPool`, but only includes non-instruction tokens.""" def __init__(self, model_config: ModelConfig): @@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module): return instruction_len def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode", "embed"} + return {"embed"} def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return PoolingParamsUpdate(requires_token_ids=True) def forward( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> TokenPoolingMethodOutput: prompt_lens = pooling_metadata.prompt_lens instr_lens = torch.tensor( [ @@ -178,7 +179,7 @@ class GritLMPooler(Pooler): super().__init__() self.pooling = GritLMMeanPool(model_config) - self.head = PoolerHead(PoolerNormalize()) + self.activation = PoolerNormalize() def get_supported_tasks(self) -> Set[PoolingTask]: return self.pooling.get_supported_tasks() @@ -186,11 +187,18 @@ class GritLMPooler(Pooler): def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return self.pooling.get_pooling_updates(task) + def head( + self, + pooled_data: TokenPoolingMethodOutput, + pooling_metadata: PoolingMetadata, + ) -> TokenPoolerHeadOutput: + return self.activation(pooled_data) + def forward( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: + ) -> TokenPoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.head(pooled_data, pooling_metadata) return pooled_data diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index fb8f6a28e..45d4a3c6a 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import ( PoolingMethod, PoolingParamsUpdate, PoolingType, + TokenPoolerHeadOutput, + TokenPoolingMethodOutput, ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.outputs import TokenPoolerOutput from vllm.v1.pool.metadata import PoolingMetadata from .interfaces import SupportsCrossEncoding @@ -300,23 +303,25 @@ class ModernBertPooler(Pooler): def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return self.pooling.get_pooling_updates(task) - def _head(self, pooled_output: torch.Tensor): - pooled_output = pooled_output.to(self.dense.weight.dtype) - return self.norm(self.act(self.dense(pooled_output))) + def head( + self, + pooled_data: TokenPoolingMethodOutput, + pooling_metadata: PoolingMetadata, + ) -> TokenPoolerHeadOutput: + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + + pooled_data = pooled_data.to(self.dense.weight.dtype) + return self.norm(self.act(self.dense(pooled_data))) def forward( self, - hidden_states: torch.Tensor | list[torch.Tensor], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> torch.Tensor | list[torch.Tensor]: - pooled_output = self.pooling(hidden_states, pooling_metadata) - - if isinstance(pooled_output, list): - pooled_output = [self._head(output) for output in pooled_output] - else: - pooled_output = self._head(pooled_output) - - return pooled_output + ) -> TokenPoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return pooled_data @default_pooling_type("CLS") diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2ac44e3bb..da92eb0a7 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, NamedTuple, TypeAlias import numpy as np import torch @@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple): # [num_reqs, ] # The shape of each element depends on the pooler used -PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None +TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] +TokensPoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None] +PoolerOutput: TypeAlias = TokenPoolerOutput | TokensPoolerOutput @dataclass diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 7ed022bb9..0764d5e6f 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -90,6 +90,12 @@ class PoolingMetadata: return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] + def get_pooling_cursor(self) -> PoolingCursor: + pooling_cursor = self.pooling_cursor + assert pooling_cursor is not None, "Should call `build_pooling_cursor` first" + + return pooling_cursor + def build_pooling_cursor( self, num_scheduled_tokens_np: np.ndarray, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5b9f0742d..9e8aaeb26 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4680,7 +4680,7 @@ class GPUModelRunner( for task in supported_pooling_tasks: # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) - output_size[task] = sum(o.nbytes for o in output) + output_size[task] = sum(o.nbytes for o in output if o is not None) del output # Allow GC max_task = max(output_size.items(), key=lambda x: x[1])[0]