[Core] Set pooling params based on task and model (#21128)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,9 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from array import array
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
from scipy.spatial.distance import cosine
|
||||
@@ -14,10 +12,6 @@ from vllm.config import ModelConfig
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
# GritLM embedding implementation is only supported by XFormers backend.
|
||||
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
|
||||
reason="GritLM requires XFormers")
|
||||
|
||||
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
|
||||
MAX_MODEL_LEN = 4000
|
||||
|
||||
@@ -26,11 +20,11 @@ def _arr(arr):
|
||||
"""
|
||||
Convert a list of integers to an array of integers.
|
||||
"""
|
||||
return array("i", arr)
|
||||
return np.array(arr)
|
||||
|
||||
|
||||
def test_find_array():
|
||||
from vllm.model_executor.models.gritlm import GritLMPooler
|
||||
from vllm.model_executor.models.gritlm import GritLMMeanPool
|
||||
|
||||
model_config = ModelConfig(
|
||||
MODEL_NAME,
|
||||
@@ -41,17 +35,19 @@ def test_find_array():
|
||||
dtype="bfloat16",
|
||||
seed=0,
|
||||
)
|
||||
pooler = GritLMPooler(model_config=model_config)
|
||||
pooling = GritLMMeanPool(model_config=model_config)
|
||||
|
||||
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
|
||||
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
|
||||
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
|
||||
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
|
||||
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
|
||||
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
|
||||
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
|
||||
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1
|
||||
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3
|
||||
assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
|
||||
pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
|
||||
|
||||
|
||||
def run_llm_encode(
|
||||
|
||||
Reference in New Issue
Block a user