[V1] Support LLM.apply_model (#18465)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-09-20 15:14:35 +08:00
committed by GitHub
parent be874c0201
commit 3d9a1d2de5
17 changed files with 194 additions and 169 deletions

View File

@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
PTPCFp8LinearMethod)
from vllm.platforms import current_platform
UNSUPPORTED_STR = (
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
"support output dtype of bfloat16. torch.float16 is specified.")
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
reason="PTPC FP8 is not supported on this GPU type.")
@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
try:
with vllm_runner("facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
llm = vllm_runner("facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
kv_cache_dtype=kv_cache_dtype)
except AssertionError as e:
if str(e) == UNSUPPORTED_STR:
# If the error message matches, the test passes
return
else:
# If the error message does not match, re-raise the exception
raise
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
with llm:
def check_model(model):
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
if kv_cache_dtype == "ptpc_fp8":
@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
if current_platform.has_device_capability(94):
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fnuz
else:
pytest.skip()
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
except AssertionError as e:
if str(
e
) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.": # noqa: E501
# If the error message matches, the test passes
pass
else:
# If the error message does not match, re-raise the exception
raise
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output