[V1] Support LLM.apply_model (#18465)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-09-20 15:14:35 +08:00
committed by GitHub
parent be874c0201
commit 3d9a1d2de5
17 changed files with 194 additions and 169 deletions

View File

@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""
import importlib
import importlib.metadata
import os
from dataclasses import dataclass
from importlib.util import find_spec
import huggingface_hub
import lm_eval
@@ -24,9 +24,8 @@ from vllm.platforms import current_platform
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
if QUARK_MXFP4_AVAILABLE:
from quark.torch.export.nn.modules.realquantizer import (
@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
}
with (vllm_runner(quark_model_id, **llm_kwargs) as
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
quark_model = (quark_handle.llm.llm_engine.model_executor.
driver_worker.model_runner.model)
quark_state_dict = quark_model.state_dict()
fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
model_runner.model)
fp8_state_dict = fp8_model.state_dict()
def get_state_dict(model):
return {k: v.cpu() for k, v in model.state_dict().items()}
quark_state_dict, = quark_handle.apply_model(get_state_dict)
fp8_state_dict, = fp8_handle.apply_model(get_state_dict)
assert fp8_state_dict.keys() == quark_state_dict.keys()