[Feature]: Support for multiple embedding types in a single inference call (#35829)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
@@ -170,4 +170,42 @@ class BOSEOSFilter(Pooler):
|
||||
return pooled_outputs
|
||||
|
||||
|
||||
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
|
||||
class BgeM3Pooler(Pooler):
|
||||
def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None:
|
||||
super().__init__()
|
||||
self.token_classify_pooler = token_classify_pooler
|
||||
self.embed_pooler = embed_pooler
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata
|
||||
) -> PoolerOutput:
|
||||
embed_outputs = self.embed_pooler(hidden_states, pooling_metadata)
|
||||
token_classify_outputs = self.token_classify_pooler(
|
||||
hidden_states, pooling_metadata
|
||||
)
|
||||
pooler_outputs: list[torch.Tensor] = []
|
||||
for embed_output, token_classify_output in zip(
|
||||
embed_outputs, token_classify_outputs
|
||||
):
|
||||
pooler_outputs.append(
|
||||
torch.cat(
|
||||
[embed_output.view(-1), token_classify_output.view(-1)], dim=-1
|
||||
)
|
||||
)
|
||||
|
||||
return pooler_outputs
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"embed&token_classify"}
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.embed_pooler.get_pooling_updates(
|
||||
"embed"
|
||||
) | self.token_classify_pooler.get_pooling_updates("token_classify")
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"supported_task={self.get_supported_tasks()}"
|
||||
return s
|
||||
|
||||
|
||||
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"]
|
||||
|
||||
Reference in New Issue
Block a user