[Feature]: Support for multiple embedding types in a single inference call (#35829)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
Augusto Yao
2026-03-17 17:05:42 +08:00
committed by GitHub
parent 132bfd45b6
commit 9c7cab5ebb
7 changed files with 226 additions and 36 deletions

View File

@@ -170,4 +170,42 @@ class BOSEOSFilter(Pooler):
return pooled_outputs
# NOTE(review): earlier export list. A later `__all__` assignment in the
# visible source rebinds this name and additionally lists "BgeM3Pooler", so
# this value does not survive module import — confirm whether this line is
# still needed.
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
class BgeM3Pooler(Pooler):
    """Pooler that fuses an embedding pooler with a token-classify pooler.

    Both wrapped poolers are run on the same hidden states and metadata.
    Their outputs are paired positionally, and each pair is flattened and
    concatenated into a single 1-D tensor (embed values first, then
    token-classify values).
    """

    def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None:
        """Store the two wrapped poolers.

        Args:
            token_classify_pooler: Pooler queried with the "token_classify"
                task and whose output forms the tail of each fused tensor.
            embed_pooler: Pooler queried with the "embed" task and whose
                output forms the head of each fused tensor.
        """
        super().__init__()
        self.token_classify_pooler = token_classify_pooler
        self.embed_pooler = embed_pooler

    def forward(
        self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata
    ) -> PoolerOutput:
        """Run both poolers and concatenate their outputs pair by pair.

        Args:
            hidden_states: Tensor forwarded unchanged to both poolers.
            pooling_metadata: Metadata forwarded unchanged to both poolers.

        Returns:
            A list with one 1-D tensor per (embed, token-classify) output
            pair. Pairing uses ``zip``, so the list is as long as the
            shorter of the two pooler outputs.
        """
        # The embed pooler is invoked first, then the token-classify pooler.
        embeddings = self.embed_pooler(hidden_states, pooling_metadata)
        token_scores = self.token_classify_pooler(hidden_states, pooling_metadata)

        # Flatten each side so outputs of different rank can share one tensor.
        return [
            torch.cat((embedding.view(-1), scores.view(-1)), dim=-1)
            for embedding, scores in zip(embeddings, token_scores)
        ]

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the single combined task name this pooler handles."""
        return {"embed&token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Merge the pooling updates of both wrapped poolers.

        ``task`` is not consulted: the embed pooler is always asked for
        "embed" and the token-classify pooler for "token_classify", and the
        two results are combined with ``|`` (embed update on the left).
        """
        embed_updates = self.embed_pooler.get_pooling_updates("embed")
        classify_updates = self.token_classify_pooler.get_pooling_updates(
            "token_classify"
        )
        return embed_updates | classify_updates

    def extra_repr(self) -> str:
        """Return the supported-task set formatted for the object's repr."""
        return f"supported_task={self.get_supported_tasks()}"
# Public names exported by a star-import of this module; the list includes
# BgeM3Pooler alongside the other pooler classes.
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"]