[Core] Support async scheduling with uniproc executor (#24219)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
commit 4fdd6f5cbf (parent 8226dd56bf)
Author: Nick Hill
Date: 2025-09-12 16:34:28 -07:00
Committed via GitHub
9 changed files with 103 additions and 55 deletions


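For reference, the combination this commit enables (async scheduling together with the single-process "uni" executor) can also be turned on directly through the LLM entry point. The snippet below is an illustrative sketch and not part of the diff; the model name is a placeholder, and it assumes the async_scheduling and distributed_executor_backend engine arguments accept the same values the test passes.

# Illustrative sketch only (not from this commit). Enables async scheduling
# together with the uniproc ("uni") executor; the model name is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",            # placeholder model
    async_scheduling=True,                # overlap scheduling with model execution
    distributed_executor_backend="uni",   # single-process (uniproc) executor
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=5),
)
print(outputs[0].outputs[0].text)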
@@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs(
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.parametrize("model_executor", ["uni", "mp"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models(
monkeypatch: pytest.MonkeyPatch,
@@ -70,6 +72,8 @@ def test_models(
     backend: str,
     max_tokens: int,
     enforce_eager: bool,
+    async_scheduling: bool,
+    model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
@@ -77,6 +81,12 @@ def test_models(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if not envs.VLLM_USE_V1:
if async_scheduling:
pytest.skip("async_scheduling only supported in v1.")
if model_executor != "uni":
pytest.skip("only test uniproc executor for v0.")
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip(
f"{backend} does not support gemma2 with full context length.")
@@ -98,11 +108,15 @@ def test_models(
                 prompt_embeds = hf_model.get_prompt_embeddings(
                     example_prompts)
-    with VllmRunner(model,
-                    max_model_len=8192,
-                    enforce_eager=enforce_eager,
-                    enable_prompt_embeds=enable_prompt_embeds,
-                    gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(
+            model,
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            enable_prompt_embeds=enable_prompt_embeds,
+            gpu_memory_utilization=0.7,
+            async_scheduling=async_scheduling,
+            distributed_executor_backend=model_executor,
+    ) as vllm_model:
         if enable_prompt_embeds:
             vllm_outputs = vllm_model.generate_greedy(
                 prompt_embeds, max_tokens)