[Core] Set pooling params based on task and model (#21128)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -18,13 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
||||
PoolingMethod, PoolingTask,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
PoolingType)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.pooling_params import PoolingTask
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
|
||||
@@ -91,8 +92,11 @@ class BertPooler(Pooler):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
||||
return self.pooling.get_pooling_params(task)
|
||||
def get_pooling_updates(
|
||||
self,
|
||||
task: PoolingTask,
|
||||
) -> Optional[PoolingParamsUpdate]:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user