[1/n][Chunked Prefill] Refactor input query shapes (#3236)
This commit is contained in:
@@ -13,6 +13,7 @@ MODELS = [
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [False, True])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
@@ -20,12 +21,13 @@ def test_models(
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user