[Feature]: Support for multiple embedding types in a single inference call (#35829)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
@@ -19,6 +19,12 @@ model_config = {
|
||||
),
|
||||
}
|
||||
|
||||
# Reference sums of the dense embedding for each test prompt.
# `_check_dense_embedding` looks an entry up by position, so the order
# of the entries matters.
dense_embedding_sum = [
    # "What is the capital of France?"
    -0.7214539647102356,
    # "What is the capital of Germany?"
    -0.6926871538162231,
    # "What is the capital of Spain?"
    -0.7129564881324768,
]
|
||||
|
||||
|
||||
def _float_close(expected: object, result: object):
|
||||
assert isinstance(expected, float) and isinstance(result, float), (
|
||||
@@ -33,6 +39,12 @@ def _get_attr_or_val(obj: object | dict, key: str):
|
||||
return getattr(obj, key, None)
|
||||
|
||||
|
||||
def _check_dense_embedding(data, index=0):
    """Assert that a dense embedding sums to its recorded reference value.

    Args:
        data: The dense embedding values; they are added up with ``sum``.
        index: Position in ``dense_embedding_sum`` of the reference sum to
            compare against. Defaults to the first entry.

    Raises:
        AssertionError: If ``_float_close`` reports that the computed sum
            and the reference sum do not match.
    """
    actual_sum = sum(data)
    # `_float_close` is declared as (expected, result), so the reference
    # value goes first and the computed sum second.
    assert _float_close(dense_embedding_sum[index], actual_sum), (
        "dense-embedding result not match"
    )
|
||||
|
||||
|
||||
def _check_sparse_embedding(data, check_tokens=False):
|
||||
expected_weights = [
|
||||
{"token_id": 32, "weight": 0.0552978515625, "token": "?"},
|
||||
@@ -109,7 +121,7 @@ async def test_bge_m3_sparse_plugin_online(
|
||||
assert len(_get_attr_or_val(parsed_response, "data")) > 0
|
||||
|
||||
data_entry = _get_attr_or_val(parsed_response, "data")[0]
|
||||
assert _get_attr_or_val(data_entry, "object") == "sparse-embedding"
|
||||
assert _get_attr_or_val(data_entry, "object") == "dense&sparse"
|
||||
assert _get_attr_or_val(data_entry, "sparse_embedding")
|
||||
|
||||
# Verify sparse embedding format
|
||||
@@ -117,6 +129,11 @@ async def test_bge_m3_sparse_plugin_online(
|
||||
assert isinstance(sparse_embedding, list)
|
||||
_check_sparse_embedding(sparse_embedding, return_tokens)
|
||||
|
||||
# Verify dense embedding format
|
||||
dense_embedding = _get_attr_or_val(data_entry, "dense_embedding")
|
||||
assert isinstance(dense_embedding, list)
|
||||
_check_dense_embedding(dense_embedding)
|
||||
|
||||
# Verify usage information
|
||||
usage = _get_attr_or_val(parsed_response, "usage")
|
||||
assert usage, f"usage not found for {parsed_response}"
|
||||
@@ -164,6 +181,9 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
|
||||
sparse_embedding = output.sparse_embedding
|
||||
assert isinstance(sparse_embedding, list)
|
||||
_check_sparse_embedding(sparse_embedding, return_tokens)
|
||||
dense_embedding = output.dense_embedding
|
||||
assert isinstance(dense_embedding, list)
|
||||
_check_dense_embedding(dense_embedding)
|
||||
|
||||
# Verify usage
|
||||
assert response.usage.prompt_tokens > 0
|
||||
@@ -206,6 +226,9 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
|
||||
# Each output should have sparse embeddings
|
||||
sparse_embedding = output.sparse_embedding
|
||||
assert isinstance(sparse_embedding, list)
|
||||
dense_embedding = output.dense_embedding
|
||||
assert isinstance(dense_embedding, list)
|
||||
_check_dense_embedding(dense_embedding, i)
|
||||
|
||||
# Verify usage
|
||||
assert response.usage.prompt_tokens > 0
|
||||
|
||||
Reference in New Issue
Block a user