[core][distributed] simplify code to support pipeline parallel (#6406)
@@ -28,10 +28,8 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
-@pytest.mark.skipif(is_hip()
-                    and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER",
-                    reason="Flashinfer does not support ROCm/HIP.")
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
@@ -40,10 +38,17 @@ def test_models(
     vllm_runner,
     example_prompts,
     model: str,
+    backend: str,
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+
+    if backend == "FLASHINFER" and is_hip():
+        pytest.skip("Flashinfer does not support ROCm/HIP.")
+
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
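Note: the change above moves the ROCm/HIP FlashInfer skip out of a module-level skipif decorator and into the test body, and selects the attention backend per test case through the VLLM_ATTENTION_BACKEND environment variable. Below is a minimal, self-contained sketch of that pattern; is_hip() is stubbed for illustration (in vLLM it detects a ROCm/HIP build), and the direct os.environ write is swapped for pytest's built-in monkeypatch fixture, so this is not vLLM's actual test code.

    import os

    import pytest


    def is_hip() -> bool:
        # Stub: vLLM's real helper reports whether this is a ROCm/HIP build.
        return False


    @pytest.mark.parametrize("backend",
                             ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
    def test_backend_selection(backend: str, monkeypatch) -> None:
        # Skip unsupported combinations inside the test body instead of a
        # module-level skipif, so each parametrized case decides for itself.
        if backend == "FLASHINFER" and is_hip():
            pytest.skip("Flashinfer does not support ROCm/HIP.")
        # monkeypatch restores the environment after the test, which is
        # safer than mutating os.environ when tests share one process.
        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
        assert os.environ["VLLM_ATTENTION_BACKEND"] == backend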