[Model] Avoid hardcoding pooling type (#32119)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import (
|
||||
PoolingParamsUpdate,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
CLSPool,
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolerOutput,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import (
|
||||
pooler_for_token_classify,
|
||||
@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module):
|
||||
|
||||
|
||||
class BertPooler(SequencePooler):
|
||||
def __init__(self, config: BertConfig):
|
||||
def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
|
||||
super().__init__(
|
||||
pooling=CLSPool(),
|
||||
pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
|
||||
head=self.head,
|
||||
)
|
||||
|
||||
@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel):
|
||||
)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = BertPooler(config, pooler_config)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
other_weights, loaded_stacked_params = self._load_weights(weights)
|
||||
@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
|
||||
)
|
||||
|
||||
# None of vLLM's built-in sequence pooling types are
|
||||
# applicable, so the pooling type is overwritten by SPLADESparsePooler
|
||||
pooling_mode = getattr(self, "_splade_pooling", "max")
|
||||
|
||||
cls_id = getattr(cfg, "cls_token_id", None)
|
||||
|
||||
Reference in New Issue
Block a user