[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -64,66 +64,6 @@ class PoolingParamsUpdate:
|
||||
params.requires_token_ids = self.requires_token_ids
|
||||
|
||||
|
||||
class Pooler(nn.Module, ABC):
|
||||
"""The interface required for all poolers used in pooling models in vLLM."""
|
||||
|
||||
@staticmethod
|
||||
def for_encode(pooler_config: PoolerConfig):
|
||||
if pooler_config.pooling_type == "STEP":
|
||||
return StepPooler()
|
||||
|
||||
resolved_config = ResolvedPoolingConfig(
|
||||
task="encode", pooling_type=PoolingType.ALL
|
||||
)
|
||||
|
||||
return SimplePooler.from_config(resolved_config)
|
||||
|
||||
@staticmethod
|
||||
def for_embed(pooler_config: PoolerConfig):
|
||||
resolved_config = ResolvedPoolingConfig.from_config(
|
||||
task="embed",
|
||||
pooler_config=pooler_config,
|
||||
)
|
||||
|
||||
return SimplePooler.from_config(resolved_config)
|
||||
|
||||
@staticmethod
|
||||
def for_classify(
|
||||
pooler_config: PoolerConfig,
|
||||
classifier: ClassifierFn | None,
|
||||
):
|
||||
resolved_config = ResolvedPoolingConfig.from_config(
|
||||
task="classify",
|
||||
pooler_config=pooler_config,
|
||||
)
|
||||
|
||||
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
|
||||
|
||||
return ClassifierPooler(
|
||||
pooling=pooling,
|
||||
classifier=classifier,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
"""Determine which pooling tasks are supported."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
"""
|
||||
Construct the updated pooling parameters to use for a supported task.
|
||||
"""
|
||||
return PoolingParamsUpdate()
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_prompt_lens(
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
|
||||
|
||||
class CLSPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"encode", "embed", "classify", "score"}
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
self,
|
||||
@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
|
||||
|
||||
class LastPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"encode", "embed", "classify", "score"}
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
self,
|
||||
@@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
|
||||
|
||||
class AllPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"encode"}
|
||||
return {"token_embed", "token_classify"}
|
||||
|
||||
def forward_all(
|
||||
self,
|
||||
@@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
|
||||
|
||||
class MeanPool(PoolingMethod):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"encode", "embed", "classify", "score"}
|
||||
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||
|
||||
def forward_all(
|
||||
self,
|
||||
@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
|
||||
return self.fn(pooled_data)
|
||||
|
||||
|
||||
class Pooler(nn.Module, ABC):
    """Abstract base class for every pooler used by pooling models in vLLM.

    The static ``for_*`` factories assemble the stock pooler for a pooling
    task from a `PoolerConfig`. Subclasses must implement
    `get_supported_tasks` and `forward`.
    """

    @staticmethod
    def _wrap_token_head(pooler_config: PoolerConfig, head: nn.Module):
        """Wrap a per-token ``head`` in the pooler selected by the config."""
        # Only the "STEP" pooling type gets a StepPooler; every other
        # pooling type is served by AllPooler.
        if pooler_config.pooling_type == "STEP":
            return StepPooler(head=head)
        return AllPooler(head=head)

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build the pooler for the ``token_embed`` task."""
        return Pooler._wrap_token_head(pooler_config, TokenEmbeddingPoolerHead())

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the pooler for the ``token_classify`` task.

        ``classifier`` and ``act_fn`` are forwarded unchanged to
        `TokenClassifierPoolerHead`.
        """
        token_head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        return Pooler._wrap_token_head(pooler_config, token_head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the pooler for the ``embed`` task."""
        resolved = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the pooler for the ``classify`` task.

        ``classifier`` and ``act_fn`` are forwarded unchanged to
        `ClassifierPooler`.
        """
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation ignores ``task`` and returns a
        default-constructed `PoolingParamsUpdate`.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class PoolerHead(nn.Module):
|
||||
def __init__(self, activation: PoolerActivation) -> None:
|
||||
super().__init__()
|
||||
@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
|
||||
super().__init__(activation=PoolerNormalize())
|
||||
|
||||
# Load ST projector if available
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.projector: nn.Module | None = (
|
||||
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
||||
@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
|
||||
return pooled_data
|
||||
|
||||
|
||||
class RewardPoolerHead(PoolerHead):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
):
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = [p.to(self.head_dtype) for p in pooled_data]
|
||||
else:
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
|
||||
# for softmax
|
||||
flags = [p.softmax for p in pooling_params]
|
||||
if len(set(flags)) == 1:
|
||||
if flags[0]:
|
||||
pooled_data = self.activation(pooled_data)
|
||||
else:
|
||||
pooled_data = [
|
||||
self.activation(vecs) if f else vecs
|
||||
for vecs, f in zip(pooled_data, flags)
|
||||
]
|
||||
|
||||
return pooled_data
|
||||
|
||||
|
||||
class SimplePooler(Pooler):
|
||||
"""A layer that pools specific information from hidden states.
|
||||
|
||||
@@ -513,20 +495,6 @@ class SimplePooler(Pooler):
|
||||
3. Returns structured results as `PoolerOutput`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
pooler_config: ResolvedPoolingConfig,
|
||||
) -> "SimplePooler":
|
||||
pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
|
||||
if pooler_config.task == "embed":
|
||||
head = EmbeddingPoolerHead()
|
||||
elif pooler_config.task == "encode":
|
||||
head = RewardPoolerHead()
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown task: {pooler_config.task}")
|
||||
return cls(pooling, head)
|
||||
|
||||
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -549,58 +517,6 @@ class SimplePooler(Pooler):
|
||||
return pooled_data
|
||||
|
||||
|
||||
class StepPooler(Pooler):
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = AllPool()
|
||||
self.head = RewardPoolerHead()
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
|
||||
|
||||
pooled_data = list[torch.Tensor]()
|
||||
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
|
||||
for data, token_id, pooling_param in zip(
|
||||
pooled_data_lst, prompt_token_ids, pooling_params
|
||||
):
|
||||
step_tag_id = pooling_param.step_tag_id
|
||||
returned_token_ids = pooling_param.returned_token_ids
|
||||
|
||||
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
||||
data = data[:, returned_token_ids]
|
||||
|
||||
if step_tag_id is not None:
|
||||
data = data[token_id == step_tag_id]
|
||||
pooled_data.append(data)
|
||||
|
||||
return pooled_data
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"encode"}
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return PoolingParamsUpdate(requires_token_ids=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
class ClassifierPooler(Pooler):
|
||||
"""A pooling layer for classification tasks.
|
||||
|
||||
@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def act_fn_for_seq_cls(config: ModelConfig):
|
||||
return get_classification_activation_function(config.hf_config)
|
||||
def act_fn_for_seq_cls(model_config: ModelConfig):
|
||||
return get_classification_activation_function(model_config.hf_config)
|
||||
|
||||
@staticmethod
|
||||
def act_fn_for_cross_encoder(config: ModelConfig):
|
||||
return get_cross_encoder_activation_function(config.hf_config)
|
||||
def act_fn_for_cross_encoder(model_config: ModelConfig):
|
||||
return get_cross_encoder_activation_function(model_config.hf_config)
|
||||
|
||||
@staticmethod
|
||||
def resolve_act_fn(
|
||||
model_config: ModelConfig,
|
||||
static_num_labels: bool = True,
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
):
|
||||
if isinstance(act_fn, str):
|
||||
if act_fn == "classify":
|
||||
return ClassifierPooler.act_fn_for_seq_cls(model_config)
|
||||
elif act_fn == "score":
|
||||
return ClassifierPooler.act_fn_for_cross_encoder(model_config)
|
||||
else:
|
||||
raise ValueError(f"act_fn [{act_fn=}] not supported.")
|
||||
elif act_fn is None:
|
||||
return PoolerClassify(static_num_labels=static_num_labels)
|
||||
else:
|
||||
assert callable(act_fn)
|
||||
return act_fn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pooling: PoolingFn,
|
||||
classifier: ClassifierFn | None,
|
||||
act_fn: PoolerActivation | None = None,
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.pooling = pooling
|
||||
self.classifier = classifier
|
||||
self.act_fn = act_fn or PoolerClassify()
|
||||
self.act_fn = self.resolve_act_fn(
|
||||
vllm_config.model_config, static_num_labels=True, act_fn=act_fn
|
||||
)
|
||||
self.logit_bias: float | None = (
|
||||
vllm_config.model_config.pooler_config.logit_bias
|
||||
)
|
||||
@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
|
||||
return scores
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Embedding head applied to the token-level states of one request."""

    def forward(
        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
    ) -> torch.Tensor:
        """Convert per-token hidden states into per-token embeddings.

        Args:
            pooled_data: Token states of a single request; the original
                comments give the shape as ``[n_tokens, hidden_dimension]``.
            pooling_param: Read for ``dimensions`` (truncation of the last
                axis) and ``normalize`` (whether to apply the activation).

        Returns:
            The embeddings, ``[n_tokens, embedding_dimension]`` per the
            original comments.
        """
        embeddings = pooled_data.to(self.head_dtype)

        # Optional ST projector, when one is present on the head.
        projector = self.projector
        if projector is not None:
            embeddings = projector(embeddings)

        # Matryoshka representation: keep only the leading `dimensions`
        # entries of the last axis.
        embeddings = embeddings[..., : pooling_param.dimensions]

        # Normalisation is opt-in per request.
        if not pooling_param.normalize:
            return embeddings
        return self.activation(embeddings)
|
||||
|
||||
|
||||
class TokenClassifierPoolerHead(nn.Module):
    """Per-token classification head: classifier, logit bias, activation."""

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        """Create the head from the current vLLM config.

        Args:
            classifier: Optional callable mapping hidden states to scores.
                When ``None`` the hidden states are used as scores directly.
            act_fn: Activation specification, resolved through
                `ClassifierPooler.resolve_act_fn` with
                ``static_num_labels=False``.
        """
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this head serves."""
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_param: PoolingParams,
    ) -> torch.Tensor:
        """Score every token of a single request.

        Args:
            hidden_states: Token states of one request; the original
                comments give the shape as ``[n_token, hidden_size]``.
                This tensor is never modified.
            pooling_param: Read for ``activation`` (whether to apply the
                activation function to the scores).

        Returns:
            The scores, ``[n_token, num_labels]`` per the original comments.
        """
        hidden_states = hidden_states.to(self.head_dtype)

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states

        if self.logit_bias is not None:
            # Out-of-place on purpose: `.to()` returns the very same tensor
            # when the dtype already matches, so without a classifier
            # `scores` aliases the caller's tensor and an in-place `-=`
            # would silently corrupt the caller's hidden states.
            scores = scores - self.logit_bias

        if pooling_param.activation:
            scores = self.act_fn(scores)

        return scores
|
||||
|
||||
|
||||
class AllPooler(Pooler):
    """Pooler that runs a per-request head over `AllPool` output."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        """Store ``head`` and create the `AllPool` pooling method."""
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the token-level tasks handled by this pooler."""
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the hidden states, then run the head once per request.

        Each pooled entry is paired with the pooling params at the same
        position and both are passed to the head.
        """
        per_request = self.pooling(hidden_states, pooling_metadata)
        params = get_pooling_params(pooling_metadata)
        assert len(per_request) == len(params)

        outputs = []
        for states, param in zip(per_request, params):
            outputs.append(self.head(states, param))
        return outputs
|
||||
|
||||
|
||||
class StepPooler(Pooler):
    """Pooler that filters `AllPool` output per request before the head.

    Filtering is driven by each request's ``returned_token_ids`` and
    ``step_tag_id`` pooling params.
    """

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        """Store ``head`` and create the `AllPool` pooling method."""
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Pool the hidden states and apply the per-request filters.

        Returns:
            A list with one (possibly filtered) tensor per request.
        """
        per_request = self.pooling(hidden_states, pooling_metadata)
        token_ids_per_request = get_prompt_token_ids(pooling_metadata)
        params = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, param in zip(
            per_request, token_ids_per_request, params
        ):
            tag_id = param.step_tag_id
            keep_ids = param.returned_token_ids

            # Restrict the second axis to the requested ids, when any.
            if keep_ids is not None and len(keep_ids) > 0:
                states = states[:, keep_ids]

            # Keep only rows whose prompt token equals the step tag.
            if tag_id is not None:
                states = states[token_ids == tag_id]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the token-level tasks handled by this pooler."""
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract the filtered states, then run the head once per request."""
        per_request = self.extract_states(hidden_states, pooling_metadata)
        params = get_pooling_params(pooling_metadata)
        assert len(per_request) == len(params)

        outputs = []
        for states, param in zip(per_request, params):
            outputs.append(self.head(states, param))
        return outputs
|
||||
|
||||
|
||||
class DispatchPooler(Pooler):
|
||||
"""Dispatches calls to a sub-pooler based on the pooling task."""
|
||||
|
||||
|
||||
@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
},
|
||||
)
|
||||
@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierPooler,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
PoolingMethod,
|
||||
PoolingType,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
model_config.hidden_size,
|
||||
config.num_labels,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
params_dtype=vllm_config.model_config.head_dtype,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=maybe_prefix(prefix, "score"),
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
pooling_type_str = pooler_config.pooling_type
|
||||
assert pooling_type_str is not None
|
||||
pooling_type = PoolingType[pooling_type_str]
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||
classifier=self._classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||
classifier=self._classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _classifier(self, x: torch.Tensor):
|
||||
x, _ = self.score(x.float())
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"encode": Pooler.for_encode(pooler_config)},
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
|
||||
|
||||
@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
|
||||
return DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": SPLADESparsePooler(
|
||||
mlm_head=self.mlm_head,
|
||||
cls_token_id=cls_id,
|
||||
@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
act_fn="classify",
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=self.new.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
act_fn="classify",
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=self.new.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"classify": Pooler.for_classify(pooler_config, classifier=self.score),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
|
||||
if pooler_config is not None:
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": GritLMPooler(vllm_config.model_config),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"encode": Pooler.for_encode(pooler_config)},
|
||||
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config,
|
||||
classifier=self.score,
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
|
||||
self.score = JinaVLScorer(vllm_config.model_config)
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"classify": Pooler.for_classify(pooler_config, classifier=self.score),
|
||||
"score": Pooler.for_classify(pooler_config, classifier=self.score),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=self.pooling,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=self.pooling, classifier=self.classifier, act_fn="classify"
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=self.pooling,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=self.pooling, classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"encode": Pooler.for_encode(pooler_config)},
|
||||
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||
)
|
||||
|
||||
|
||||
@@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)})
|
||||
self.pooler = DispatchPooler(
|
||||
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||
)
|
||||
|
||||
@@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of BertModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
"""A model that uses Roberta to provide embedding functionalities."""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
@@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"encode": Pooler.for_encode(pooler_config)},
|
||||
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
|
||||
@@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
@@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user