[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)
Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
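The diff below extends the basic-correctness tests so that prompt-embedding inputs are exercised both on a single GPU and with tensor_parallel_size=2, the case this bugfix addresses. For orientation, the end-to-end usage the tests cover looks roughly like the sketch that follows; it is not part of the diff, the model name is only an example, and the {"prompt_embeds": ...} request shape for LLM.generate is assumed from vLLM's prompt-embeddings support rather than shown in this commit.

# Sketch only: greedy generation from prompt embeddings under tensor
# parallelism. Requires a vLLM build with prompt-embeds support and 2 GPUs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "distilbert/distilgpt2"  # example model from the test matrix
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

# Turn a text prompt into embeddings using the HF input-embedding layer.
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(input_ids).squeeze(0)

# enable_prompt_embeds must be on; tensor_parallel_size=2 is the fixed path.
llm = LLM(model=model_name,
          enable_prompt_embeds=True,
          tensor_parallel_size=2)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)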
@@ -8,12 +8,13 @@ import weakref
 from unittest.mock import Mock
 
 import pytest
+import torch
 
-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
 
-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
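Why the helper above is needed: when vLLM is driven by raw embeddings it never sees the prompt's token ids or text, so the leading ids it reports for the embedded prompt are not the original prompt tokens and the returned text lacks the prompt. The helper swaps in the HF tokenizer's prompt ids and prepends the prompt string so the outputs can be compared against HF's prompt-plus-completion outputs. A toy, self-contained illustration of that splice (all values made up):

# Toy data: what the splice in _fix_prompt_embed_outputs amounts to.
prompt = "Hello world"
hf_prompt_ids = [15496, 995]                       # ids from the HF tokenizer
vllm_output = ([0, 0, 11, 703, 389], ", how are")  # leading ids are not real prompt tokens

fixed_ids = hf_prompt_ids + vllm_output[0][len(hf_prompt_ids):]
fixed_text = prompt + vllm_output[1]

assert fixed_ids == [15496, 995, 11, 703, 389]
assert fixed_text == "Hello world, how are"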
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:
 
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
@@ -78,14 +99,25 @@ def test_models(
 
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(
+                    example_prompts)
 
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
+                    enable_prompt_embeds=enable_prompt_embeds,
                     gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(
+                prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
 
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
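hf_model.get_prompt_embeddings comes from the HfRunner fixture in the test conftest and is not defined in this diff. Conceptually it pushes each tokenized prompt through the HF model's input embedding layer; the stand-in below is a rough sketch under that assumption (the real helper may handle batching, dtype, and device placement differently).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_prompt_embeddings_sketch(model_name: str,
                                 prompts: list[str]) -> list[torch.Tensor]:
    # Approximate HfRunner.get_prompt_embeddings: one [seq_len, hidden]
    # tensor per prompt, taken from the model's input embedding layer.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    embed_layer = model.get_input_embeddings()
    embeds = []
    with torch.no_grad():
        for prompt in prompts:
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids
            embeds.append(embed_layer(input_ids).squeeze(0))
    return embeds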
@@ -108,6 +140,7 @@ def test_models(
         ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
         ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
     ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:
 
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")
 
     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
             # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
 
@@ -147,12 +188,26 @@ def test_models_distributed(
             dtype=dtype,
             tensor_parallel_size=2,
             distributed_executor_backend=distributed_executor_backend,
+            enable_prompt_embeds=enable_prompt_embeds,
+            gpu_memory_utilization=0.7,
     ) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with hf_runner(model, dtype=dtype) as hf_model:
+                with torch.no_grad():
+                    prompt_embeds = hf_model.get_prompt_embeddings(
+                        example_prompts)
+                vllm_outputs = vllm_model.generate_greedy(
+                    prompt_embeds, max_tokens)
+                vllm_outputs = _fix_prompt_embed_outputs(
+                    vllm_outputs, hf_model, example_prompts)
+                hf_outputs = hf_model.generate_greedy(
+                    example_prompts, max_tokens)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
+            with hf_runner(model, dtype=dtype) as hf_model:
+                hf_outputs = hf_model.generate_greedy(
+                    example_prompts, max_tokens)
 
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
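Both tests end by calling check_outputs_equal (imported from ..models.utils), which asserts that each prompt's (token_ids, text) pair matches across the two runners. The stand-in below is a simplified sketch of that comparison, not the real helper, which also prints a labelled report on mismatch.

def check_outputs_equal_sketch(outputs_0_lst, outputs_1_lst, name_0, name_1):
    # Each entry is a (token_ids, text) tuple for one prompt; greedy decoding
    # should make the two runners agree exactly.
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        assert out_0 == out_1, (
            f"Mismatch on prompt {i}: {name_0}={out_0!r} vs {name_1}={out_1!r}")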