[Core] Rework dtype resolution (#18751)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs(
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [False])
|
||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||
@@ -69,7 +68,6 @@ def test_models(
|
||||
hf_runner,
|
||||
model: str,
|
||||
backend: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
enable_prompt_embeds: bool,
|
||||
@@ -97,7 +95,7 @@ def test_models(
|
||||
str(i) for i in range(1024)) + " are:"
|
||||
example_prompts = [prompt]
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
@@ -106,7 +104,6 @@ def test_models(
|
||||
|
||||
with VllmRunner(model,
|
||||
max_model_len=8192,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7) as vllm_model:
|
||||
|
||||
Reference in New Issue
Block a user