Fix KV sharing fast prefill with cudagraph enabled (#28537)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Yong Hoon Shin
2025-11-14 01:53:42 -10:00
committed by GitHub
parent 4516d44b7f
commit 9324e10275
3 changed files with 17 additions and 57 deletions

View File

@@ -4,13 +4,11 @@
import random
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode
from vllm.distributed import cleanup_dist_env_and_memory
from ...utils import fork_new_process_for_each_test
from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
# global seed
SEED = 42
@@ -45,28 +43,12 @@ def test_prompts():
return prompts
def cleanup(llm: LLM, compilation_config: CompilationConfig):
# hacky: below lines are required to free up memory for the next test
# when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
# TODO(sarckk): when enforce_eager=False, memory is not freed:
# find out why and re-enable test for enforce_eager=False case
llm_engine = llm.llm_engine.engine_core.engine_core
model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
del model_runner.model
del model_runner.kv_caches
del compilation_config.static_forward_context
compilation_config.static_forward_context = {}
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
monkeypatch: pytest.MonkeyPatch,
kv_sharing_fast_prefill: bool,
enforce_eager: bool,
test_prompts: list[str],
):
@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
if not enforce_eager
else CompilationMode.NONE,
)
batch_size = 10
with monkeypatch.context() as m:
# Make scheduling deterministic for reproducibility
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
llm = LLM(
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
seed=SEED,
)
ref_responses = llm.generate(test_prompts, sampling_params)
cleanup(llm, compilation_config)
prompts, answer, indices = prep_prompts(batch_size)
llm = LLM(
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
seed=SEED,
kv_sharing_fast_prefill=True,
kv_sharing_fast_prefill=kv_sharing_fast_prefill,
)
responses = llm.generate(prompts, sampling_params)
check_answers(
indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0,
)
optimized_responses = llm.generate(test_prompts, sampling_params)
cleanup(llm, compilation_config)
misses = 0
for ref_response, optimized_response in zip(ref_responses, optimized_responses):
if ref_response.outputs[0].text != optimized_response.outputs[0].text:
misses += 1
assert misses == 0