[Core] Set pooling params based on task and model (#21128)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 20:41:17 +08:00
committed by GitHub
parent 4adc66f64d
commit 45badd05d0
24 changed files with 509 additions and 241 deletions

View File

@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import importlib.util
from array import array
import numpy as np
import openai
import pytest
from scipy.spatial.distance import cosine
@@ -14,10 +12,6 @@ from vllm.config import ModelConfig
from ....utils import RemoteOpenAIServer
# GritLM embedding implementation is only supported by XFormers backend.
# Module-level pytest marker: skip when the `xformers` package cannot be
# located by importlib.
# NOTE(review): the enclosing hunk header (-14,10 +12,6) drops four lines, so
# this comment/pytestmark pair is presumably what the commit removes; the +/-
# markers are missing from this rendering -- confirm against the commit.
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
reason="GritLM requires XFormers")
# Model identifier handed to ModelConfig in test_find_array.
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
# NOTE(review): presumably the context-length cap used when launching the
# model; its use is not visible in this excerpt -- confirm.
MAX_MODEL_LEN = 4000
@@ -26,11 +20,11 @@ def _arr(arr):
"""
Convert a list of integers to an array of integers.
"""
# NOTE(review): two return statements appear back to back, so the second is
# unreachable if read literally. This looks like a diff rendering with the +/-
# markers stripped: presumably the array("i", ...) form is the removed line
# and the np.array(...) form is its replacement -- confirm against the commit
# before treating either one as current.
return array("i", arr)
return np.array(arr)
def test_find_array():
# Exercises the private `_find_array` helper of the GritLM pooling class:
# searching for a target sub-array inside a larger array of integers.
# NOTE(review): throughout this function the lines come in old/new pairs
# (GritLMPooler + `pooler` vs. GritLMMeanPool + `pooling`). Presumably the
# `pooler` lines are removed by the commit and the `pooling` lines are
# added; the diff's +/- markers and indentation are missing -- confirm.
from vllm.model_executor.models.gritlm import GritLMPooler
from vllm.model_executor.models.gritlm import GritLMMeanPool
model_config = ModelConfig(
MODEL_NAME,
# NOTE(review): the hunk header below marks a jump in the diff; any
# ModelConfig arguments between MODEL_NAME and dtype are not shown here.
@@ -41,17 +35,19 @@ def test_find_array():
dtype="bfloat16",
seed=0,
)
pooler = GritLMPooler(model_config=model_config)
pooling = GritLMMeanPool(model_config=model_config)
# Haystack is the integers 0..9, so each value equals its own index.
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
# Expected results: the index where the target begins (3), or -1 when the
# target is not found from the given start position.
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
# end_idx cases exist only for `pooling`: end_idx=3 misses, end_idx=4 hits.
# NOTE(review): presumably end_idx is an exclusive bound on where a match
# may start -- confirm against the _find_array implementation.
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3
# Non-contiguous values ([3, 5]) must not be reported as a match.
assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1
# A negative start_idx is expected to raise ValueError.
with pytest.raises(ValueError):
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
def run_llm_encode(