[Refactor] Clean up pooler modules (#31897)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
PoolingType,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.outputs import TokenPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def _head(self, pooled_output: torch.Tensor):
|
||||
pooled_output = self.dense(pooled_output)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
def head(
|
||||
self,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
|
||||
pooled_data = self.dense(pooled_data)
|
||||
pooled_data = self.activation(pooled_data)
|
||||
return pooled_data
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
||||
|
||||
if isinstance(pooled_output, list):
|
||||
pooled_output = [self._head(output) for output in pooled_output]
|
||||
else:
|
||||
pooled_output = self._head(pooled_output)
|
||||
|
||||
return pooled_output
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user