[Misc] Replace os.environ with monkeypatch in test suite (#14516)

Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Sibi
2025-03-17 11:35:57 +08:00
committed by GitHub
parent 1e799b7ec1
commit a73e183e36
43 changed files with 1900 additions and 1658 deletions

View File

@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import importlib.util
import math
@@ -11,6 +12,7 @@ from scipy.spatial.distance import cosine
import vllm
import vllm.config
from vllm.utils import STR_BACKEND_ENV_VAR
from ....utils import RemoteOpenAIServer
@@ -29,36 +31,34 @@ def _arr(arr):
return array("i", arr)
def test_find_array(monkeypatch):
def test_find_array(monkeypatch: pytest.MonkeyPatch):
# GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
from vllm.model_executor.models.gritlm import GritLMPooler
from vllm.model_executor.models.gritlm import GritLMPooler
# Create an LLM object to get the model config.
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
# Create an LLM object to get the model config.
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
with pytest.raises(ValueError):
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
with pytest.raises(ValueError):
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
@pytest.fixture(scope="module")
def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend.
with pytest.MonkeyPatch.context() as mp:
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
@@ -69,9 +69,12 @@ def server_generate():
@pytest_asyncio.fixture
async def client_embedding(server_embedding: RemoteOpenAIServer):
async with server_embedding.get_async_client() as async_client:
yield async_client
async def client_embedding(monkeypatch: pytest.MonkeyPatch,
server_embedding: RemoteOpenAIServer):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
@@ -80,14 +83,20 @@ async def client_generate(server_generate: RemoteOpenAIServer):
yield async_client
def run_llm_encode(llm: vllm.LLM, queries: list[str],
instruction: str) -> list[float]:
def run_llm_encode(
llm: vllm.LLM,
queries: list[str],
instruction: str,
) -> list[float]:
outputs = llm.encode([instruction + q for q in queries], )
return [output.outputs.embedding for output in outputs]
async def run_client_embeddings(client: vllm.LLM, queries: list[str],
instruction: str) -> list[float]:
async def run_client_embeddings(
client: vllm.LLM,
queries: list[str],
instruction: str,
) -> list[float]:
outputs = await client.embeddings.create(
model=MODEL_NAME,
input=[instruction + q for q in queries],
@@ -106,7 +115,7 @@ def get_test_data():
README.md in https://github.com/ContextualAI/gritlm
"""
q_instruction = gritlm_instruction(
"Given a scientific paper title, retrieve the paper's abstract")
"Given a scientific paper title, retrieve the paper's abstract", )
queries = [
"Bitcoin: A Peer-to-Peer Electronic Cash System",
"Generative Representational Instruction Tuning",
@@ -136,31 +145,32 @@ def validate_embed_output(q_rep: list[float], d_rep: list[float]):
assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)
def test_gritlm_offline_embedding(monkeypatch):
def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch):
# GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
queries, q_instruction, documents, d_instruction = get_test_data()
queries, q_instruction, documents, d_instruction = get_test_data()
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
d_rep = run_llm_encode(
llm,
documents,
d_instruction,
)
q_rep = run_llm_encode(
llm,
queries,
q_instruction,
)
d_rep = run_llm_encode(
llm,
documents,
d_instruction,
)
q_rep = run_llm_encode(
llm,
queries,
q_instruction,
)
validate_embed_output(q_rep, d_rep)
validate_embed_output(q_rep, d_rep)
@pytest.mark.asyncio
async def test_gritlm_api_server_embedding(
client_embedding: openai.AsyncOpenAI):
client_embedding: openai.AsyncOpenAI, ):
queries, q_instruction, documents, d_instruction = get_test_data()
d_rep = await run_client_embeddings(