[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-10-15 19:14:41 +08:00
committed by GitHub
parent d4d1a6024f
commit f54f85129e
41 changed files with 786 additions and 399 deletions

View File

@@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM):
     ]
     # Multiple PoolingParams should be matched with each prompt
-    outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
+    outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed")
     assert len(PROMPTS) == len(outputs)
     # Exception raised, if the size of params does not match the size of prompts
     with pytest.raises(ValueError):
-        outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
+        outputs = llm.encode(
+            PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed"
+        )
     # Single PoolingParams should be applied to every prompt
     single_pooling_params = PoolingParams()
-    outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
+    outputs = llm.encode(
+        PROMPTS, pooling_params=single_pooling_params, pooling_task="embed"
+    )
     assert len(PROMPTS) == len(outputs)
     # pooling_params is None, default params should be applied
-    outputs = llm.encode(PROMPTS, pooling_params=None)
+    outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed")
     assert len(PROMPTS) == len(outputs)