[Refactor] Clean up pooler modules (#31897)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
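The recurring pattern in the diff below is that each pooler's forward() is reduced to a pooling step followed by a head step, with head() (rather than forward()) taking care of stacking list outputs. For orientation only, here is a self-contained sketch of that shape; the class, the pooling placeholder, and the type alias are illustrative stand-ins (PoolingMetadata is omitted), not the vLLM Pooler/PoolingMethod classes touched by this commit.

import torch
import torch.nn as nn

# Stand-in for TokenPoolingMethodOutput: the pooling step may hand the head
# either a single stacked tensor or a list of per-request tensors.
TokenPoolingMethodOutput = torch.Tensor | list[torch.Tensor]


class SketchPooler(nn.Module):
    """Illustrative only: mirrors the refactored BertPooler/ModernBertPooler
    shape, where forward() is pooling -> head and head() handles lists."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def pooling(self, hidden_states: torch.Tensor) -> TokenPoolingMethodOutput:
        # Placeholder pooling method: take the first (CLS) token per sequence.
        return hidden_states[:, 0]

    def head(self, pooled_data: TokenPoolingMethodOutput) -> torch.Tensor:
        # As in the diff: a list of per-request tensors is stacked before the
        # dense projection and activation are applied.
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        return self.activation(self.dense(pooled_data))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled_data = self.pooling(hidden_states)
        return self.head(pooled_data)


if __name__ == "__main__":
    pooler = SketchPooler(hidden_size=8)
    print(pooler(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 8])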
@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
     PoolingMethod,
     PoolingParamsUpdate,
     PoolingType,
+    TokenPoolerHeadOutput,
+    TokenPoolingMethodOutput,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
+from vllm.v1.outputs import TokenPoolerOutput
 from vllm.v1.pool.metadata import PoolingMetadata
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
-    def _head(self, pooled_output: torch.Tensor):
-        pooled_output = self.dense(pooled_output)
-        pooled_output = self.activation(pooled_output)
-        return pooled_output
+    def head(
+        self,
+        pooled_data: TokenPoolingMethodOutput,
+        pooling_metadata: PoolingMetadata,
+    ) -> TokenPoolerHeadOutput:
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+
+        pooled_data = self.dense(pooled_data)
+        pooled_data = self.activation(pooled_data)
+        return pooled_data
 
     def forward(
         self,
-        hidden_states: torch.Tensor | list[torch.Tensor],
+        hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> torch.Tensor | list[torch.Tensor]:
-        pooled_output = self.pooling(hidden_states, pooling_metadata)
-
-        if isinstance(pooled_output, list):
-            pooled_output = [self._head(output) for output in pooled_output]
-        else:
-            pooled_output = self._head(pooled_output)
-
-        return pooled_output
+    ) -> TokenPoolerOutput:
+        pooled_data = self.pooling(hidden_states, pooling_metadata)
+        pooled_data = self.head(pooled_data, pooling_metadata)
+        return pooled_data
 
 
 class BertEncoder(nn.Module):
@@ -4,21 +4,22 @@ from collections.abc import Set
 
 import numpy as np
 import torch
 import torch.nn as nn
 
 from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import (
     DispatchPooler,
     Pooler,
-    PoolerHead,
     PoolerNormalize,
     PoolingMethod,
     PoolingParamsUpdate,
+    TokenPoolerHeadOutput,
+    TokenPoolingMethodOutput,
 )
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.tasks import PoolingTask
 from vllm.tokenizers import cached_tokenizer_from_config
-from vllm.v1.outputs import PoolerOutput
+from vllm.v1.outputs import TokenPoolerOutput
 from vllm.v1.pool.metadata import PoolingMetadata
 
 from .interfaces_base import default_pooling_type
@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
 logger = init_logger(__name__)
 
 
-class GritLMMeanPool(nn.Module):
+class GritLMMeanPool(PoolingMethod):
     """As `MeanPool`, but only includes non-instruction tokens."""
 
     def __init__(self, model_config: ModelConfig):
@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
         return instruction_len
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
-        return {"encode", "embed"}
+        return {"embed"}
 
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return PoolingParamsUpdate(requires_token_ids=True)
 
     def forward(
         self,
-        hidden_states: torch.Tensor | list[torch.Tensor],
+        hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> TokenPoolingMethodOutput:
         prompt_lens = pooling_metadata.prompt_lens
         instr_lens = torch.tensor(
             [
@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
         super().__init__()
 
         self.pooling = GritLMMeanPool(model_config)
-        self.head = PoolerHead(PoolerNormalize())
+        self.activation = PoolerNormalize()
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return self.pooling.get_supported_tasks()
@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
+    def head(
+        self,
+        pooled_data: TokenPoolingMethodOutput,
+        pooling_metadata: PoolingMetadata,
+    ) -> TokenPoolerHeadOutput:
+        return self.activation(pooled_data)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> PoolerOutput:
+    ) -> TokenPoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
         pooled_data = self.head(pooled_data, pooling_metadata)
         return pooled_data
@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
     PoolingMethod,
     PoolingParamsUpdate,
     PoolingType,
+    TokenPoolerHeadOutput,
+    TokenPoolingMethodOutput,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
+from vllm.v1.outputs import TokenPoolerOutput
 from vllm.v1.pool.metadata import PoolingMetadata
 
 from .interfaces import SupportsCrossEncoding
@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
-    def _head(self, pooled_output: torch.Tensor):
-        pooled_output = pooled_output.to(self.dense.weight.dtype)
-        return self.norm(self.act(self.dense(pooled_output)))
+    def head(
+        self,
+        pooled_data: TokenPoolingMethodOutput,
+        pooling_metadata: PoolingMetadata,
+    ) -> TokenPoolerHeadOutput:
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+
+        pooled_data = pooled_data.to(self.dense.weight.dtype)
+        return self.norm(self.act(self.dense(pooled_data)))
 
     def forward(
         self,
-        hidden_states: torch.Tensor | list[torch.Tensor],
+        hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> torch.Tensor | list[torch.Tensor]:
-        pooled_output = self.pooling(hidden_states, pooling_metadata)
-
-        if isinstance(pooled_output, list):
-            pooled_output = [self._head(output) for output in pooled_output]
-        else:
-            pooled_output = self._head(pooled_output)
-
-        return pooled_output
+    ) -> TokenPoolerOutput:
+        pooled_data = self.pooling(hidden_states, pooling_metadata)
+        pooled_data = self.head(pooled_data, pooling_metadata)
+        return pooled_data
 
 
 @default_pooling_type("CLS")
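One detail specific to the GritLM hunks above: GritLMPooler drops the PoolerHead(PoolerNormalize()) wrapper and instead keeps self.activation = PoolerNormalize(), which its new head() applies directly. Assuming PoolerNormalize L2-normalizes the pooled embedding (an assumption about that helper, not something shown in this diff), the head step amounts to roughly:

import torch
import torch.nn.functional as F


def gritlm_head(pooled_data: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour of PoolerNormalize: L2-normalize each pooled embedding.
    return F.normalize(pooled_data, p=2, dim=-1)


print(gritlm_head(torch.randn(4, 16)).norm(dim=-1))  # each entry is ~1.0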