[Attention] Update tests to remove deprecated env vars (#30563)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
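In short, these tests stop selecting the attention backend through the deprecated VLLM_ATTENTION_BACKEND environment variable (set via pytest's monkeypatch) and pass it to the runner as attention_config instead. A minimal sketch of the new pattern, assuming VllmRunner is the test helper these files already use and that the import path below is correct; the wrapper function name is illustrative only:

    # Sketch of the pattern this change adopts; not the full test.
    from tests.conftest import VllmRunner  # assumed import location for the test runner

    def generate_with_backend(model: str, backend: str, prompts: list[str], max_tokens: int):
        # Old style (removed here): monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
        # New style: select the backend explicitly through attention_config.
        with VllmRunner(model, attention_config={"backend": backend}) as vllm_model:
            return vllm_model.generate_greedy(prompts, max_tokens)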
@@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs(
 @pytest.mark.parametrize("model_executor", ["uni", "mp"])
 @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
-    monkeypatch: pytest.MonkeyPatch,
     hf_runner,
     model: str,
     backend: str,
@@ -77,48 +76,46 @@ def test_models(
     model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", backend)
     # 5042 tokens for gemma2
     # gemma2 has alternating sliding window size of 4096
     # we need a prompt with more than 4096 tokens to test the sliding window
     prompt = (
         "The following numbers of the sequence "
         + ", ".join(str(i) for i in range(1024))
         + " are:"
     )
     example_prompts = [prompt]

     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
         if enable_prompt_embeds:
             with torch.no_grad():
                 prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)

     with VllmRunner(
         model,
         max_model_len=8192,
         enforce_eager=enforce_eager,
         enable_prompt_embeds=enable_prompt_embeds,
         gpu_memory_utilization=0.7,
         async_scheduling=async_scheduling,
         distributed_executor_backend=model_executor,
+        attention_config={"backend": backend},
     ) as vllm_model:
         if enable_prompt_embeds:
             vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
             vllm_outputs = _fix_prompt_embed_outputs(
                 vllm_outputs, hf_model, example_prompts
             )
         else:
             vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,
         name_0="hf",
         name_1="vllm",
     )

@multi_gpu_test(num_gpus=2)
@@ -161,12 +158,6 @@ def test_models_distributed(
         ): # noqa
             pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")

-        if attention_backend:
-            monkeypatch_context.setenv(
-                "VLLM_ATTENTION_BACKEND",
-                attention_backend,
-            )
-
         for k, v in extra_env.items():
             monkeypatch_context.setenv(k, v)

@@ -178,6 +169,7 @@ def test_models_distributed(
         # if we run HF first, the cuda initialization will be done and it
         # will hurt multiprocessing backend with fork method
         # (the default method).
+        attention_config = {"backend": attention_backend} if attention_backend else None
         with vllm_runner(
             model,
             dtype=dtype,
@@ -185,6 +177,7 @@ def test_models_distributed(
             distributed_executor_backend=distributed_executor_backend,
             enable_prompt_embeds=enable_prompt_embeds,
             gpu_memory_utilization=0.7,
+            attention_config=attention_config,
         ) as vllm_model:
             if enable_prompt_embeds:
                 with hf_runner(model, dtype=dtype) as hf_model:
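For the distributed test, the backend may be unset, so the config is built conditionally before being handed to vllm_runner. A small sketch of that guard, using the names from the hunk above (the helper function is illustrative only):

    # Sketch of the conditional shown above: only build a config when a backend is given.
    def build_attention_config(attention_backend: str | None) -> dict | None:
        # Returning None leaves the runner on its default attention backend selection.
        return {"backend": attention_backend} if attention_backend else None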