[Model] GritLM supports other attention backends (#18109)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -11,7 +11,6 @@ from scipy.spatial.distance import cosine
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@@ -117,44 +116,37 @@ def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
|
||||
assert math.isclose(cosine_sim_q1_d1, 0.534, abs_tol=0.001)
|
||||
|
||||
|
||||
def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
|
||||
vllm_runner):
|
||||
# GritLM embedding implementation is only supported by XFormers backend.
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
|
||||
def test_gritlm_offline_embedding(vllm_runner):
|
||||
queries, q_instruction, documents, d_instruction = get_test_data()
|
||||
|
||||
queries, q_instruction, documents, d_instruction = get_test_data()
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
task="embed",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
) as vllm_model:
|
||||
llm = vllm_model.model
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
task="embed",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
) as vllm_model:
|
||||
llm = vllm_model.model
|
||||
d_rep = run_llm_encode(
|
||||
llm,
|
||||
documents,
|
||||
d_instruction,
|
||||
)
|
||||
q_rep = run_llm_encode(
|
||||
llm,
|
||||
queries,
|
||||
q_instruction,
|
||||
)
|
||||
|
||||
d_rep = run_llm_encode(
|
||||
llm,
|
||||
documents,
|
||||
d_instruction,
|
||||
)
|
||||
q_rep = run_llm_encode(
|
||||
llm,
|
||||
queries,
|
||||
q_instruction,
|
||||
)
|
||||
|
||||
validate_embed_output(q_rep, d_rep)
|
||||
validate_embed_output(q_rep, d_rep)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gritlm_api_server_embedding():
|
||||
queries, q_instruction, documents, d_instruction = get_test_data()
|
||||
|
||||
# GritLM embedding implementation is only supported by XFormers backend.
|
||||
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
|
||||
env_dict = {STR_BACKEND_ENV_VAR: "XFORMERS"}
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as server:
|
||||
client_embedding = server.get_async_client()
|
||||
|
||||
d_rep = await run_client_embeddings(
|
||||
@@ -172,35 +164,28 @@ async def test_gritlm_api_server_embedding():
|
||||
|
||||
|
||||
def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
|
||||
# GritLM embedding implementation is only supported by XFormers backend.
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
|
||||
input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
|
||||
|
||||
input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
task="generate",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
) as vllm_model:
|
||||
llm = vllm_model.model
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
task="generate",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
) as vllm_model:
|
||||
llm = vllm_model.model
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
|
||||
outputs = llm.generate(input, sampling_params=sampling_params)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
|
||||
outputs = llm.generate(input, sampling_params=sampling_params)
|
||||
|
||||
assert outputs[0].outputs[0].text == "The capital of France is Paris."
|
||||
assert outputs[0].outputs[0].text == "The capital of France is Paris."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gritlm_api_server_generate():
|
||||
input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
|
||||
|
||||
# GritLM embedding implementation is only supported by XFormers backend.
|
||||
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
|
||||
env_dict = {"VLLM_USE_V1": "0", STR_BACKEND_ENV_VAR: "XFORMERS"}
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as server:
|
||||
client_generate = server.get_async_client()
|
||||
|
||||
outputs = await client_generate.completions.create(
|
||||
|
||||
Reference in New Issue
Block a user