[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM):
|
||||
]
|
||||
|
||||
# Multiple PoolingParams should be matched with each prompt
|
||||
outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
|
||||
outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed")
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
# Exception raised, if the size of params does not match the size of prompts
|
||||
with pytest.raises(ValueError):
|
||||
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
|
||||
outputs = llm.encode(
|
||||
PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed"
|
||||
)
|
||||
|
||||
# Single PoolingParams should be applied to every prompt
|
||||
single_pooling_params = PoolingParams()
|
||||
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
|
||||
outputs = llm.encode(
|
||||
PROMPTS, pooling_params=single_pooling_params, pooling_task="embed"
|
||||
)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
# pooling_params is None, default params should be applied
|
||||
outputs = llm.encode(PROMPTS, pooling_params=None)
|
||||
outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed")
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user