[NVFP4] Support NVFP4 dense models from modelopt and compressed-tensors on AMD Instinct MI300, MI355X and Hopper through emulation (#35733)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
fxmarty-amd
2026-04-07 00:18:27 +02:00
committed by GitHub
parent 9c81f35b1a
commit 00d7b497b3
10 changed files with 191 additions and 58 deletions

View File

@@ -89,22 +89,33 @@ def test_models(example_prompts, model_name) -> None:
EAGER = [True, False]
SM_100_NVFP4_BACKENDS = [
"flashinfer-cudnn",
"flashinfer-trtllm",
"flashinfer-cutlass",
]
@pytest.mark.skipif(
not current_platform.has_device_capability(100),
reason="modelopt_fp4 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"])
@pytest.mark.parametrize("eager", EAGER)
@pytest.mark.parametrize(
"backend",
[
"emulation",
"flashinfer-cudnn",
"flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used
"flashinfer-cutlass",
],
)
def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
if (
not current_platform.has_device_capability(100)
and backend in SM_100_NVFP4_BACKENDS
):
pytest.skip(
f"The backend {backend} is not supported with current_platform.has_device_capability(100) == False"
)
monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)
with vllm_runner(model, enforce_eager=eager) as llm:
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)

View File

@@ -366,9 +366,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
assert output
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize(
"args",
[
@@ -398,7 +395,7 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
print(output)
assert output