[core] gemma2 full context length support (#10584)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-22 20:13:54 -08:00
Committed by: GitHub
Parent: 978b39744b
Commit: 4aba6e3d1a
4 changed files with 54 additions and 23 deletions
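Background for the test change below: gemma2 interleaves two attention patterns, with alternating layers attending through a 4096-token sliding window while the others attend globally over the model's full 8192-token context; before this change vLLM effectively capped gemma2 at the 4096-token window. The following is a minimal illustrative sketch of the two mask shapes, assuming numpy — it is not code from this commit, and the layer-to-pattern assignment is simplified:

# Illustrative sketch only (not from this commit): build the two causal mask
# variants that alternate across gemma2's layers.
import numpy as np

def causal_mask(n: int, window: int = 0) -> np.ndarray:
    """True where query position i may attend to key position j."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    mask = j <= i                  # causal: no attending to the future
    if window:
        mask &= (i - j) < window   # sliding window: only the last `window` keys
    return mask

full_ctx = 8192
global_mask = causal_mask(full_ctx)              # global-attention layers
local_mask = causal_mask(full_ctx, window=4096)  # sliding-window layers

# A query at position 5041 (as in the ~5042-token test prompt) sees every
# prefix token on global layers but only the most recent 4096 on local ones.
q = 5041
assert global_mask[q].sum() == q + 1
assert local_mask[q].sum() == 4096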


@@ -14,11 +14,12 @@ from vllm import LLM
 from vllm.platforms import current_platform
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test

 MODELS = [
     "facebook/opt-125m",
+    "google/gemma-2-2b-it",
     "meta-llama/Llama-3.2-1B",
 ]
@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
 @pytest.mark.parametrize("enforce_eager", [False, True])
 def test_models(
     hf_runner,
-    vllm_runner,
-    example_prompts,
     model: str,
     backend: str,
     dtype: str,
@@ -54,15 +53,27 @@ def test_models(
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
+    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
+        pytest.skip(
+            "XFORMERS does not support gemma2 with full context length.")
     os.environ["VLLM_ATTENTION_BACKEND"] = backend

+    # 5042 tokens for gemma2
+    # gemma2 has alternating sliding window size of 4096
+    # we need a prompt with more than 4096 tokens to test the sliding window
+    prompt = "The following numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

-    with vllm_runner(model,
-                     dtype=dtype,
-                     enforce_eager=enforce_eager,
-                     gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(model,
+                    max_model_len=8192,
+                    dtype=dtype,
+                    enforce_eager=enforce_eager,
+                    gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

     check_outputs_equal(
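Taken together, the new test drives a prompt past the 4096-token sliding window and compares vLLM against HF greedy decoding at max_model_len=8192. A rough way to check the "5042 tokens" comment and exercise the full context outside the test harness is sketched below, assuming transformers and vllm are installed and the gemma-2 weights are available on a GPU; this is not part of the commit, and exact token counts may vary with tokenizer version:

# Sketch only: verify the prompt-length claim and run gemma2 past its
# sliding window, relying on this commit's full-context support.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Rebuild the prompt used by the test: 1024 comma-separated integers.
prompt = "The following numbers of the sequence " + ", ".join(
    str(i) for i in range(1024)) + " are:"

tok = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
n_tokens = len(tok(prompt).input_ids)
print(n_tokens)          # ~5042, i.e. past the 4096-token sliding window
assert n_tokens > 4096

# With this change, max_model_len may exceed gemma2's 4096-token sliding
# window, up to the model's full 8192-token context.
llm = LLM(model="google/gemma-2-2b-it", max_model_len=8192)
out = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)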