[V1] Support LLM.apply_model (#18465)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
|
||||
See also `tests/kernels/moe/test_mxfp4_moe.py`.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
|
||||
import huggingface_hub
|
||||
import lm_eval
|
||||
@@ -24,9 +24,8 @@ from vllm.platforms import current_platform
|
||||
|
||||
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
|
||||
"quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
|
||||
if QUARK_MXFP4_AVAILABLE:
|
||||
from quark.torch.export.nn.modules.realquantizer import (
|
||||
@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This module relies on V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
|
||||
@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
|
||||
}
|
||||
with (vllm_runner(quark_model_id, **llm_kwargs) as
|
||||
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
|
||||
quark_model = (quark_handle.llm.llm_engine.model_executor.
|
||||
driver_worker.model_runner.model)
|
||||
quark_state_dict = quark_model.state_dict()
|
||||
|
||||
fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
fp8_state_dict = fp8_model.state_dict()
|
||||
def get_state_dict(model):
|
||||
return {k: v.cpu() for k, v in model.state_dict().items()}
|
||||
|
||||
quark_state_dict, = quark_handle.apply_model(get_state_dict)
|
||||
fp8_state_dict, = fp8_handle.apply_model(get_state_dict)
|
||||
|
||||
assert fp8_state_dict.keys() == quark_state_dict.keys()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user