[1/n][Chunked Prefill] Refactor input query shapes (#3236)

This commit is contained in:
SangBin Cho
2024-03-21 06:46:05 +09:00
committed by GitHub
parent 426ec4ec67
commit 6e435de766
18 changed files with 579 additions and 263 deletions

View File

@@ -13,6 +13,7 @@ MODELS = [
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models(
hf_runner,
vllm_runner,
@@ -20,12 +21,13 @@ def test_models(
model: str,
dtype: str,
max_tokens: int,
enforce_eager: bool,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model