Fix KV sharing fast prefill with cudagraph enabled (#28537)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,13 +4,11 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationMode
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
from ...utils import fork_new_process_for_each_test
|
||||
from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
|
||||
|
||||
# global seed
|
||||
SEED = 42
|
||||
@@ -45,28 +43,12 @@ def test_prompts():
|
||||
return prompts
|
||||
|
||||
|
||||
def cleanup(llm: LLM, compilation_config: CompilationConfig):
    """Release what ``llm`` holds so the next test can allocate memory.

    Deletes the model and KV caches from the model runner reached through
    ``llm``, resets ``compilation_config.static_forward_context`` to an
    empty dict, empties the CUDA cache, and finally calls
    ``cleanup_dist_env_and_memory()``.

    Args:
        llm: The engine instance whose model runner should be stripped.
        compilation_config: The config whose static forward context is
            replaced with a fresh empty dict.
    """
    # HACK: with VLLM_ENABLE_V1_MULTIPROCESSING=0, `del llm` alone is not
    # sufficient to free up memory for the next test, so the attributes
    # below are deleted explicitly.
    # TODO(sarckk): when enforce_eager=False, memory is not freed:
    # find out why and re-enable test for enforce_eager=False case
    engine_core = llm.llm_engine.engine_core.engine_core
    worker = engine_core.model_executor.driver_worker.worker
    runner = worker.model_runner
    del runner.model
    del runner.kv_caches

    # Drop the old forward context, then install an empty one.
    del compilation_config.static_forward_context
    compilation_config.static_forward_context = {}

    # NOTE: this unbinds only the local name; the caller's own reference
    # to the engine is unaffected.
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
|
||||
@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_kv_sharing_fast_prefill(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
kv_sharing_fast_prefill: bool,
|
||||
enforce_eager: bool,
|
||||
test_prompts: list[str],
|
||||
):
|
||||
@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
|
||||
if not enforce_eager
|
||||
else CompilationMode.NONE,
|
||||
)
|
||||
batch_size = 10
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
# Make scheduling deterministic for reproducibility
|
||||
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
llm = LLM(
|
||||
model="google/gemma-3n-E2B-it",
|
||||
enforce_eager=enforce_eager,
|
||||
compilation_config=compilation_config,
|
||||
seed=SEED,
|
||||
)
|
||||
ref_responses = llm.generate(test_prompts, sampling_params)
|
||||
|
||||
cleanup(llm, compilation_config)
|
||||
prompts, answer, indices = prep_prompts(batch_size)
|
||||
|
||||
llm = LLM(
|
||||
model="google/gemma-3n-E2B-it",
|
||||
enforce_eager=enforce_eager,
|
||||
compilation_config=compilation_config,
|
||||
seed=SEED,
|
||||
kv_sharing_fast_prefill=True,
|
||||
kv_sharing_fast_prefill=kv_sharing_fast_prefill,
|
||||
)
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
check_answers(
|
||||
indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0,
|
||||
)
|
||||
optimized_responses = llm.generate(test_prompts, sampling_params)
|
||||
|
||||
cleanup(llm, compilation_config)
|
||||
|
||||
misses = 0
|
||||
|
||||
for ref_response, optimized_response in zip(ref_responses, optimized_responses):
|
||||
if ref_response.outputs[0].text != optimized_response.outputs[0].text:
|
||||
misses += 1
|
||||
|
||||
assert misses == 0
|
||||
|
||||
Reference in New Issue
Block a user