[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-10-15 19:14:41 +08:00
committed by GitHub
parent d4d1a6024f
commit f54f85129e
41 changed files with 786 additions and 399 deletions

View File

@@ -64,66 +64,6 @@ class PoolingParamsUpdate:
params.requires_token_ids = self.requires_token_ids
class Pooler(nn.Module, ABC):
    """Base interface implemented by every pooler used by vLLM pooling models."""

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the pooler for the "encode" task.

        A pooling type of "STEP" yields a `StepPooler`; any other value yields
        a `SimplePooler` built from a config resolved to `PoolingType.ALL`.
        """
        if pooler_config.pooling_type != "STEP":
            return SimplePooler.from_config(
                ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL)
            )
        return StepPooler()

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the pooler for the "embed" task from `pooler_config`."""
        embed_config = ResolvedPoolingConfig.from_config(
            task="embed", pooler_config=pooler_config
        )
        return SimplePooler.from_config(embed_config)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
    ):
        """Build a `ClassifierPooler` for the "classify" task.

        The pooling method is derived from the pooling type resolved out of
        `pooler_config`; `classifier` is forwarded unchanged.
        """
        classify_config = ResolvedPoolingConfig.from_config(
            task="classify", pooler_config=pooler_config
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(classify_config.pooling_type),
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update to apply for a supported task.

        The base implementation returns a default-constructed
        `PoolingParamsUpdate` regardless of `task`.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states`; subclasses must provide the implementation."""
        raise NotImplementedError
def get_prompt_lens(
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
self,
@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
self,
@@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
return {"token_embed", "token_classify"}
def forward_all(
self,
@@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
self,
@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
return self.fn(pooled_data)
class Pooler(nn.Module, ABC):
    """Base interface implemented by every pooler used by vLLM pooling models."""

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build the pooler for the "token_embed" task.

        A `TokenEmbeddingPoolerHead` is wrapped by `StepPooler` when the
        pooling type is "STEP", otherwise by `AllPooler`.
        """
        head = TokenEmbeddingPoolerHead()
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the pooler for the "token_classify" task.

        A `TokenClassifierPoolerHead` (given `classifier` and `act_fn`) is
        wrapped by `StepPooler` when the pooling type is "STEP", otherwise by
        `AllPooler`.
        """
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build a `SimplePooler` with an `EmbeddingPoolerHead` for "embed"."""
        embed_config = ResolvedPoolingConfig.from_config(
            task="embed", pooler_config=pooler_config
        )
        # Keyword arguments evaluate left to right, so the pooling method is
        # still created before the head.
        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(embed_config.pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a `ClassifierPooler` for the "classify" task.

        The pooling method is derived from the pooling type resolved out of
        `pooler_config`; `classifier` and `act_fn` are forwarded unchanged.
        """
        classify_config = ResolvedPoolingConfig.from_config(
            task="classify", pooler_config=pooler_config
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(classify_config.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update to apply for a supported task.

        The base implementation returns a default-constructed
        `PoolingParamsUpdate` regardless of `task`.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states`; subclasses must provide the implementation."""
        raise NotImplementedError
class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()
@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
super().__init__(activation=PoolerNormalize())
# Load ST projector if available
vllm_config = get_current_vllm_config()
self.projector: nn.Module | None = (
_load_st_projector(vllm_config.model_config) if vllm_config else None
@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
return pooled_data
class RewardPoolerHead(PoolerHead):
    """Head that casts pooled data to the head dtype and applies the
    classification activation to the requests that ask for softmax."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify(static_num_labels=False))
        self.head_dtype = get_current_vllm_config().model_config.head_dtype

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ):
        """Cast `pooled_data` and apply the activation per request.

        Args:
            pooled_data: A single tensor or a list of tensors.
            pooling_metadata: Source of the pooling params whose `softmax`
                flag decides whether the activation is applied.

        Returns:
            The (possibly activated) data. When the flags are not all
            identical the result is a list with one entry per request.
        """
        pooled_data = (
            [item.to(self.head_dtype) for item in pooled_data]
            if isinstance(pooled_data, list)
            else pooled_data.to(self.head_dtype)
        )

        # for softmax
        softmax_flags = [
            param.softmax for param in get_pooling_params(pooling_metadata)
        ]

        if len(set(softmax_flags)) != 1:
            # Mixed (or no) flags: decide request by request.
            return [
                self.activation(vecs) if wants_softmax else vecs
                for vecs, wants_softmax in zip(pooled_data, softmax_flags)
            ]

        # Uniform flag: activate the whole batch at once, or not at all.
        if softmax_flags[0]:
            pooled_data = self.activation(pooled_data)
        return pooled_data
class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states.
@@ -513,20 +495,6 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
"""
@classmethod
def from_config(
    cls,
    pooler_config: ResolvedPoolingConfig,
) -> "SimplePooler":
    """Create a pooler whose head matches the task of `pooler_config`.

    The pooling method comes from `pooler_config.pooling_type`. The "embed"
    task gets an `EmbeddingPoolerHead`, "encode" gets a `RewardPoolerHead`.

    Raises:
        NotImplementedError: If the task is neither "embed" nor "encode".
    """
    # Resolve the pooling method first so its errors take precedence over the
    # unknown-task error below.
    pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)

    task = pooler_config.task
    if task == "embed":
        return cls(pooling, EmbeddingPoolerHead())
    if task == "encode":
        return cls(pooling, RewardPoolerHead())
    raise NotImplementedError(f"Unknown task: {pooler_config.task}")
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
super().__init__()
@@ -549,58 +517,6 @@ class SimplePooler(Pooler):
return pooled_data
class StepPooler(Pooler):
    """Pooler for the "encode" task that keeps selected token states per
    request and passes them through a `RewardPoolerHead`."""

    def __init__(self) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Return one tensor per request after applying its selections.

        For each request, `returned_token_ids` (when non-empty) indexes the
        second dimension, then `step_tag_id` (when set) keeps only the rows
        whose prompt token id equals the tag.
        """
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, param in zip(
            per_request_states, prompt_token_ids, pooling_params
        ):
            tag_id = param.step_tag_id
            wanted_columns = param.returned_token_ids

            if wanted_columns is not None and len(wanted_columns) > 0:
                states = states[:, wanted_columns]

            if tag_id is not None:
                states = states[token_ids == tag_id]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler serves."""
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request token ids; `extract_states` reads the prompt token ids."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract the per-request states and run them through the head."""
        extracted = self.extract_states(hidden_states, pooling_metadata)
        return self.head(extracted, pooling_metadata)
class ClassifierPooler(Pooler):
"""A pooling layer for classification tasks.
@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
"""
@staticmethod
def act_fn_for_seq_cls(config: ModelConfig):
return get_classification_activation_function(config.hf_config)
def act_fn_for_seq_cls(model_config: ModelConfig):
return get_classification_activation_function(model_config.hf_config)
@staticmethod
def act_fn_for_cross_encoder(config: ModelConfig):
return get_cross_encoder_activation_function(config.hf_config)
def act_fn_for_cross_encoder(model_config: ModelConfig):
return get_cross_encoder_activation_function(model_config.hf_config)
@staticmethod
def resolve_act_fn(
    model_config: ModelConfig,
    static_num_labels: bool = True,
    act_fn: PoolerActivation | str | None = None,
):
    """Turn an activation spec into the activation to use.

    Args:
        model_config: Passed to the model-derived activation lookups.
        static_num_labels: Forwarded to `PoolerClassify` for the default.
        act_fn: `None` for the default `PoolerClassify`; "classify" or
            "score" to derive the activation from `model_config`; or a
            callable, which is returned as is.

    Raises:
        ValueError: If `act_fn` is a string other than "classify"/"score".
    """
    if act_fn is None:
        return PoolerClassify(static_num_labels=static_num_labels)

    if not isinstance(act_fn, str):
        assert callable(act_fn)
        return act_fn

    if act_fn == "classify":
        return ClassifierPooler.act_fn_for_seq_cls(model_config)
    if act_fn == "score":
        return ClassifierPooler.act_fn_for_cross_encoder(model_config)
    raise ValueError(f"act_fn [{act_fn=}] not supported.")
def __init__(
self,
pooling: PoolingFn,
classifier: ClassifierFn | None,
act_fn: PoolerActivation | None = None,
act_fn: PoolerActivation | str | None = None,
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
self.pooling = pooling
self.classifier = classifier
self.act_fn = act_fn or PoolerClassify()
self.act_fn = self.resolve_act_fn(
vllm_config.model_config, static_num_labels=True, act_fn=act_fn
)
self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias
)
@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
return scores
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Embedding head applied to the token-level states of one request."""

    def forward(
        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
    ) -> torch.Tensor:
        """Cast, optionally project, truncate and optionally normalize.

        Args:
            pooled_data: Token states, [n_tokens, hidden_dimension].
            pooling_param: Supplies `dimensions` (truncation of the last
                axis) and `normalize` (whether the activation is applied).

        Returns:
            Token embeddings, [n_tokens, embedding_dimension].
        """
        # [n_tokens, hidden_dimension]
        token_states = pooled_data.to(self.head_dtype)

        # Apply ST projector -> [n_tokens, embedding_dimension]
        projector = self.projector
        if projector is not None:
            token_states = projector(token_states)

        # for matryoshka representation
        token_states = token_states[..., : pooling_param.dimensions]

        # for normalize
        if not pooling_param.normalize:
            return token_states
        return self.activation(token_states)
class TokenClassifierPoolerHead(nn.Module):
    """Head that scores every token of one request for token classification.

    The head casts the token states to the configured head dtype, optionally
    runs a classifier over them, subtracts the configured logit bias and
    optionally applies the resolved activation.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        """Read the head settings from the current vLLM config.

        Args:
            classifier: Optional callable applied to the token states; when
                `None` the states themselves are used as scores.
            act_fn: Activation spec resolved through
                `ClassifierPooler.resolve_act_fn` with
                `static_num_labels=False`.
        """
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this head serves."""
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_param: PoolingParams,
    ) -> torch.Tensor:
        """Compute per-token scores without modifying the input tensor.

        Args:
            hidden_states: Token states, [n_token, hidden_size].
            pooling_param: Its `activation` flag decides whether the
                activation function is applied.

        Returns:
            Scores, [n_token, num_labels].
        """
        hidden_states = hidden_states.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Subtract out of place: without a classifier, and when `.to()`
            # had nothing to convert, `scores` is the caller's own tensor and
            # an in-place `-=` would corrupt the caller's hidden states.
            scores = scores - self.logit_bias

        if pooling_param.activation:
            scores = self.act_fn(scores)
        # scores shape: [n_token, num_labels]
        return scores
class AllPooler(Pooler):
    """Pooler that applies a head to the `AllPool` output of each request."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler serves."""
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the hidden states and run the head once per request.

        Each pooled entry is paired with the pooling params at the same
        position; the two sequences must have equal length.
        """
        per_request = self.pooling(hidden_states, pooling_metadata)
        params = get_pooling_params(pooling_metadata)
        assert len(per_request) == len(params)

        outputs = []
        for states, param in zip(per_request, params):
            outputs.append(self.head(states, param))
        return outputs
class StepPooler(Pooler):
    """Pooler that keeps selected token states per request and applies the
    given head to each request's states."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return one tensor per request after applying its selections.

        For each request, `returned_token_ids` (when non-empty) indexes the
        second dimension, then `step_tag_id` (when set) keeps only the rows
        whose prompt token id equals the tag.
        """
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, param in zip(
            per_request_states, prompt_token_ids, pooling_params
        ):
            tag_id = param.step_tag_id
            wanted_columns = param.returned_token_ids

            if wanted_columns is not None and len(wanted_columns) > 0:
                states = states[:, wanted_columns]

            if tag_id is not None:
                states = states[token_ids == tag_id]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler serves."""
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request token ids; `extract_states` reads the prompt token ids."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract the per-request states and run the head once per request.

        Each extracted entry is paired with the pooling params at the same
        position; the two sequences must have equal length.
        """
        extracted = self.extract_states(hidden_states, pooling_metadata)
        params = get_pooling_params(pooling_metadata)
        assert len(extracted) == len(params)

        outputs = []
        for states, param in zip(extracted, params):
            outputs.append(self.head(states, param))
        return outputs
class DispatchPooler(Pooler):
"""Dispatches calls to a sub-pooler based on the pooling task."""