[Model] Reorganize pooling layers (#31973)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
45
vllm/model_executor/layers/pooler/seqwise/__init__.py
Normal file
45
vllm/model_executor/layers/pooler/seqwise/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Poolers that produce an output aggregating all tokens in the sequence."""
|
||||
|
||||
from .heads import (
|
||||
ClassifierPoolerHead,
|
||||
EmbeddingPoolerHead,
|
||||
SequencePoolerHead,
|
||||
SequencePoolerHeadOutput,
|
||||
)
|
||||
from .methods import (
|
||||
CLSPool,
|
||||
LastPool,
|
||||
MeanPool,
|
||||
SequencePoolingMethod,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from .poolers import (
|
||||
SequencePooler,
|
||||
SequencePoolerOutput,
|
||||
SequencePoolingFn,
|
||||
SequencePoolingHeadFn,
|
||||
pooler_for_classify,
|
||||
pooler_for_embed,
|
||||
)
|
||||
|
||||
# Public API of the sequence-wise pooling package, grouped by module:
# heads, then methods, then poolers.
__all__ = [
    "SequencePoolerHead",
    "SequencePoolerHeadOutput",
    "ClassifierPoolerHead",
    "EmbeddingPoolerHead",
    "SequencePoolingMethod",
    "SequencePoolingMethodOutput",
    "CLSPool",
    "LastPool",
    "MeanPool",
    "get_seq_pooling_method",
    "SequencePooler",
    "SequencePoolingFn",
    "SequencePoolingHeadFn",
    "SequencePoolerOutput",
    "pooler_for_classify",
    "pooler_for_embed",
]
|
||||
157
vllm/model_executor/layers/pooler/seqwise/heads.py
Normal file
157
vllm/model_executor/layers/pooler/seqwise/heads.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
PoolerNormalize,
|
||||
resolve_classifier_act_fn,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .methods import SequencePoolingMethodOutput
|
||||
|
||||
# A single batched tensor, or a list of per-sequence tensors when the
# per-request outputs cannot share one shape.
SequencePoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
|
||||
class SequencePoolerHead(nn.Module, ABC):
    """Base class for heads that postprocess sequence-pooled data.

    A head receives the output of a sequence pooling method (one vector
    per sequence) and transforms it into the final output for the
    task(s) it supports.
    """

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Transform per-sequence pooled data into the head's output."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class EmbeddingPoolerHead(SequencePoolerHead):
    """Head for the "embed" task.

    Applies (in order): cast to the head dtype, an optional
    Sentence-Transformers projection, per-request matryoshka dimension
    truncation, and per-request normalization.
    """

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        # None when the checkpoint does not ship a dense projection module.
        self.projector = _load_st_projector(model_config)
        self.head_dtype = model_config.head_dtype

        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Produce one embedding per sequence in the batch.

        Returns a single stacked tensor when every request shares the
        same output dimension, otherwise a list of per-sequence tensors.
        """
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # for matryoshka representation
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Mixed dimensions: fall back to a ragged list of tensors.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            # Uniform flag: keep the batched fast path.
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            # Mixed flags: normalize only the requests that asked for it.
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
|
||||
|
||||
|
||||
class ClassifierPoolerHead(SequencePoolerHead):
    """Head for the "classify" and "score" tasks.

    Applies (in order): cast to the head dtype, an optional classifier
    projection, an optional logit bias, and a per-request activation.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        """
        Args:
            classifier: Optional callable mapping pooled hidden states to
                per-label logits; skipped when None.
            act_fn: Activation override; resolved against the model config
                when None or given as a string.
        """
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.classifier = classifier
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

        self.act_fn = resolve_classifier_act_fn(
            model_config, static_num_labels=True, act_fn=act_fn
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Compute classification scores for each pooled sequence."""
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        # Apply the activation only for requests that asked for it; the
        # uniform case stays batched, the mixed case becomes a list.
        flags = [p.use_activation for p in pooling_params]
        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
|
||||
93
vllm/model_executor/layers/pooler/seqwise/methods.py
Normal file
93
vllm/model_executor/layers/pooler/seqwise/methods.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
# One stacked tensor (one row per sequence), or a list of per-sequence
# tensors.
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
|
||||
class SequencePoolingMethod(nn.Module, ABC):
    """Base class for methods that reduce the token hidden states of each
    sequence to a single vector (e.g. CLS / last-token / mean pooling)."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # By default a pooling method can back every pooling task;
        # heads narrow the set further.
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # No parameter updates are required by default.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Reduce token hidden states to one vector per sequence."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class CLSPool(SequencePoolingMethod):
    """Pool each sequence by taking the hidden state of its first
    ([CLS]) token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()

        # The first token is only available once the full prompt has been
        # scheduled, so partial prefill cannot be supported here.
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_token_idx = cursor.first_token_indices_gpu
        return hidden_states[first_token_idx]
|
||||
|
||||
|
||||
class LastPool(SequencePoolingMethod):
    """Pool each sequence by taking the hidden state of its final token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        last_token_idx = pooling_metadata.get_pooling_cursor().last_token_indices_gpu
        return hidden_states[last_token_idx]
|
||||
|
||||
|
||||
class MeanPool(SequencePoolingMethod):
    """Pool each sequence by averaging the hidden states of all its
    tokens, computed via a prefix sum over the flattened token axis."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # The mean needs every token of the prompt, so partial prefill
        # cannot be supported here.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu

        # Inclusive per-sequence sum via prefix sums:
        # sum_i = cumsum[end_i] - cumsum[start_i] + hidden_states[start_i]
        return (
            cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
        ) / prompt_lens.unsqueeze(1)
|
||||
|
||||
|
||||
def get_seq_pooling_method(pooling_type: PoolingTypeStr | str):
|
||||
if pooling_type == "LAST":
|
||||
return LastPool()
|
||||
if pooling_type == "CLS":
|
||||
return CLSPool()
|
||||
if pooling_type == "MEAN":
|
||||
return MeanPool()
|
||||
|
||||
raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")
|
||||
106
vllm/model_executor/layers/pooler/seqwise/poolers.py
Normal file
106
vllm/model_executor/layers/pooler/seqwise/poolers.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
|
||||
from vllm.model_executor.layers.pooler.abstract import Pooler
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerActivation
|
||||
from vllm.tasks import POOLING_TASKS, PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .heads import (
|
||||
ClassifierPoolerHead,
|
||||
EmbeddingPoolerHead,
|
||||
SequencePoolerHead,
|
||||
SequencePoolerHeadOutput,
|
||||
)
|
||||
from .methods import (
|
||||
SequencePoolingMethod,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
|
||||
# A pooling-method module or any bare callable with the same signature
# as `SequencePoolingMethod.forward`.
SequencePoolingFn: TypeAlias = Callable[
    [torch.Tensor, PoolingMetadata],
    SequencePoolingMethodOutput,
]
# A pooler-head module or any bare callable with the same signature as
# `SequencePoolerHead.forward`.
SequencePoolingHeadFn: TypeAlias = Callable[
    [SequencePoolingMethodOutput, PoolingMetadata],
    SequencePoolerHeadOutput,
]

# One batched tensor, or a list of per-sequence tensors when outputs
# cannot share one shape.
SequencePoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
|
||||
class SequencePooler(Pooler):
    """Pools one output per sequence from token hidden states.

    The pooler runs two stages in order:

    1. ``pooling`` reduces each sequence's token hidden states to one
       vector (or extracts specific tokens).
    2. ``head`` postprocesses those vectors into the final
       `PoolerOutput`-shaped result.
    """

    def __init__(
        self,
        pooling: SequencePoolingMethod | SequencePoolingFn,
        head: SequencePoolerHead | SequencePoolingHeadFn,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Start from every task and intersect with what each stage
        # declares; bare callables place no restriction.
        supported = set(POOLING_TASKS)
        if isinstance(self.pooling, SequencePoolingMethod):
            supported &= self.pooling.get_supported_tasks()
        if isinstance(self.head, SequencePoolerHead):
            supported &= self.head.get_supported_tasks()
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        updates = PoolingParamsUpdate()
        if isinstance(self.pooling, SequencePoolingMethod):
            updates = updates | self.pooling.get_pooling_updates(task)
        return updates

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerOutput:
        per_sequence = self.pooling(hidden_states, pooling_metadata)
        return self.head(per_sequence, pooling_metadata)
|
||||
|
||||
|
||||
def pooler_for_embed(pooler_config: PoolerConfig):
    """Build a `SequencePooler` for the "embed" task from ``pooler_config``."""
    return SequencePooler(
        pooling=get_seq_pooling_method(pooler_config.get_pooling_type()),
        head=EmbeddingPoolerHead(),
    )
|
||||
|
||||
|
||||
def pooler_for_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a `SequencePooler` for the "classify"/"score" tasks.

    Args:
        pooler_config: Source of the default pooling type.
        pooling: Optional pooling override; derived from the config's
            pooling type when None.
        classifier: Optional classifier passed to `ClassifierPoolerHead`.
        act_fn: Optional activation override for the head.
    """
    resolved_pooling = (
        get_seq_pooling_method(pooler_config.get_pooling_type())
        if pooling is None
        else pooling
    )
    return SequencePooler(
        pooling=resolved_pooling,
        head=ClassifierPoolerHead(classifier=classifier, act_fn=act_fn),
    )
|
||||
Reference in New Issue
Block a user