[Misc] Replace os.environ with monkeypatch in test suite (#14516)
Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
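For context, a minimal sketch of the pattern this change adopts, with illustrative names (`EXAMPLE_BACKEND` and `configured_backend` are not from this diff): `monkeypatch.context()` scopes an environment change to a block and restores the previous state on exit, while `pytest.MonkeyPatch.context()` gives the same behavior in module-scoped fixtures, where the function-scoped `monkeypatch` fixture cannot be requested.

```python
import os

import pytest


def test_env_var_is_scoped(monkeypatch: pytest.MonkeyPatch):
    # Assumes EXAMPLE_BACKEND is not already set in the environment.
    with monkeypatch.context() as m:
        m.setenv("EXAMPLE_BACKEND", "XFORMERS")
        assert os.environ["EXAMPLE_BACKEND"] == "XFORMERS"
    # The patch is rolled back when the context exits, unlike a bare
    # `os.environ["EXAMPLE_BACKEND"] = "XFORMERS"` write.
    assert "EXAMPLE_BACKEND" not in os.environ


@pytest.fixture(scope="module")
def configured_backend():  # hypothetical fixture, for illustration only
    # Requesting the function-scoped `monkeypatch` fixture from a
    # module-scoped fixture raises ScopeMismatch, so the class-level
    # context manager is used instead.
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv("EXAMPLE_BACKEND", "XFORMERS")
        yield
```

The diff below applies these same two patterns to the GritLM tests.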
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations

 import importlib.util
 import math
@@ -11,6 +12,7 @@ from scipy.spatial.distance import cosine

 import vllm
 import vllm.config
+from vllm.utils import STR_BACKEND_ENV_VAR

 from ....utils import RemoteOpenAIServer

@@ -29,36 +31,34 @@ def _arr(arr):
     return array("i", arr)


-def test_find_array(monkeypatch):
+def test_find_array(monkeypatch: pytest.MonkeyPatch):
     # GritLM embedding implementation is only supported by XFormers backend.
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+    with monkeypatch.context() as m:
+        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")

-    from vllm.model_executor.models.gritlm import GritLMPooler
+        from vllm.model_executor.models.gritlm import GritLMPooler

-    # Create an LLM object to get the model config.
-    llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
-    pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
+        # Create an LLM object to get the model config.
+        llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
+        pooler = GritLMPooler(model_config=llm.llm_engine.model_config)

-    arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+        arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
-    assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
+        assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
+        assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
+        assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
+        assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1

-    with pytest.raises(ValueError):
-        pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
+        with pytest.raises(ValueError):
+            pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)


 @pytest.fixture(scope="module")
 def server_embedding():
     # GritLM embedding implementation is only supported by XFormers backend.
-    with pytest.MonkeyPatch.context() as mp:
-        mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
-
-        args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
-        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
-            yield remote_server
+    args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
@@ -69,9 +69,12 @@ def server_generate():


 @pytest_asyncio.fixture
-async def client_embedding(server_embedding: RemoteOpenAIServer):
-    async with server_embedding.get_async_client() as async_client:
-        yield async_client
+async def client_embedding(monkeypatch: pytest.MonkeyPatch,
+                           server_embedding: RemoteOpenAIServer):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+        async with server_embedding.get_async_client() as async_client:
+            yield async_client


 @pytest_asyncio.fixture
@@ -80,14 +83,20 @@ async def client_generate(server_generate: RemoteOpenAIServer):
         yield async_client


-def run_llm_encode(llm: vllm.LLM, queries: list[str],
-                   instruction: str) -> list[float]:
+def run_llm_encode(
+    llm: vllm.LLM,
+    queries: list[str],
+    instruction: str,
+) -> list[float]:
     outputs = llm.encode([instruction + q for q in queries], )
     return [output.outputs.embedding for output in outputs]


-async def run_client_embeddings(client: vllm.LLM, queries: list[str],
-                                instruction: str) -> list[float]:
+async def run_client_embeddings(
+    client: vllm.LLM,
+    queries: list[str],
+    instruction: str,
+) -> list[float]:
     outputs = await client.embeddings.create(
         model=MODEL_NAME,
         input=[instruction + q for q in queries],
@@ -106,7 +115,7 @@ def get_test_data():
     README.md in https://github.com/ContextualAI/gritlm
     """
     q_instruction = gritlm_instruction(
-        "Given a scientific paper title, retrieve the paper's abstract")
+        "Given a scientific paper title, retrieve the paper's abstract", )
     queries = [
         "Bitcoin: A Peer-to-Peer Electronic Cash System",
         "Generative Representational Instruction Tuning",
@@ -136,31 +145,32 @@ def validate_embed_output(q_rep: list[float], d_rep: list[float]):
     assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)


-def test_gritlm_offline_embedding(monkeypatch):
+def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch):
     # GritLM embedding implementation is only supported by XFormers backend.
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+    with monkeypatch.context() as m:
+        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")

-    queries, q_instruction, documents, d_instruction = get_test_data()
+        queries, q_instruction, documents, d_instruction = get_test_data()

-    llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
+        llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)

-    d_rep = run_llm_encode(
-        llm,
-        documents,
-        d_instruction,
-    )
-    q_rep = run_llm_encode(
-        llm,
-        queries,
-        q_instruction,
-    )
+        d_rep = run_llm_encode(
+            llm,
+            documents,
+            d_instruction,
+        )
+        q_rep = run_llm_encode(
+            llm,
+            queries,
+            q_instruction,
+        )

-    validate_embed_output(q_rep, d_rep)
+        validate_embed_output(q_rep, d_rep)


 @pytest.mark.asyncio
 async def test_gritlm_api_server_embedding(
-        client_embedding: openai.AsyncOpenAI):
+        client_embedding: openai.AsyncOpenAI, ):
     queries, q_instruction, documents, d_instruction = get_test_data()

     d_rep = await run_client_embeddings(