[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
     Fp8Config,
     Fp8KVCacheMethod,
     Fp8LinearMethod,
-    Fp8MoeBackend,
     Fp8MoEMethod,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -278,8 +277,18 @@ def test_scaled_fp8_quant(dtype) -> None:
 # this is the case for marlin as well as per-tensor Fp8MoEMethod
 @pytest.mark.parametrize("use_marlin", [False])  # skip True
 def test_fp8_reloading(
-    method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
+    method_cls,
+    is_checkpoint_fp8_serialized,
+    weight_block_size,
+    use_marlin,
+    dist_init,
+    monkeypatch,
 ):
+    # NOTE(rob): this test fails when using DeepGEMM because the
+    # shapes are invalid. Previously the test was passing because
+    # we set fp8_backend to None, which sidestepped the issue.
+    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")
+
     if is_checkpoint_fp8_serialized is False:
         pytest.skip("FP8 weight reloading does not support online quantization")

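For context: `monkeypatch`, now threaded through the test signature, is pytest's built-in fixture for scoped environment changes, and `setenv` restores the prior value on teardown. A minimal sketch of the pattern, where `select_fp8_backend` is a hypothetical stand-in rather than vLLM's real selection logic:

import os

def select_fp8_backend() -> str:
    # Hypothetical helper that consults the same environment variable
    # the test above pins to "0".
    if os.environ.get("VLLM_USE_DEEP_GEMM", "0") == "1":
        return "deepgemm"
    return "default"

def test_backend_pinned(monkeypatch):
    # monkeypatch.setenv applies only for this test and is undone on
    # teardown, unlike a bare os.environ assignment.
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")
    assert select_fp8_backend() == "default"
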
@@ -307,6 +316,7 @@ def test_fp8_reloading(
             params_dtype=torch.bfloat16,
             weight_loader=default_weight_loader,
         )
+        method.use_marlin = use_marlin

     else:
         layer = FusedMoE(
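
Setting `use_marlin` directly on the freshly constructed method forces the Marlin path in the test instead of leaving it to hardware auto-detection. A hedged sketch of that flag-toggling pattern (the `QuantMethod` class is illustrative, not vLLM's):

class QuantMethod:
    # Illustrative quantization method with a backend flag similar in
    # spirit to Fp8LinearMethod.use_marlin.
    def __init__(self) -> None:
        self.use_marlin = False  # normally chosen by hardware detection

    def dispatch(self) -> str:
        return "marlin" if self.use_marlin else "cutlass"

def test_forced_backend():
    method = QuantMethod()
    method.use_marlin = True  # override auto-detection for the test
    assert method.dispatch() == "marlin"
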
@@ -325,11 +335,6 @@ def test_fp8_reloading(
             weight_loader=default_weight_loader,
         )

-    # Fp8LinearMethod uses use_marlin
-    # Fp8MoEMethod uses fp8_backend
-    method.use_marlin = use_marlin
-    method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
-
     # capture weights format during loading
     original_metadata = [
         (name, param.shape, getattr(param, "weight_loader", default_weight_loader))
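
The "capture weights format during loading" step snapshots each parameter's name, shape, and weight loader so the test can later assert that a reload reproduces the same format. A minimal sketch of that capture-and-compare pattern, using a plain `nn.Linear` in place of vLLM's layers:

import torch
from torch import nn

def capture_metadata(module: nn.Module) -> list:
    # Snapshot (name, shape) for every parameter at load time.
    return [(name, param.shape) for name, param in module.named_parameters()]

def test_reload_keeps_format():
    layer = nn.Linear(16, 16)
    original_metadata = capture_metadata(layer)
    # Simulate a weight reload by overwriting parameter data in place.
    with torch.no_grad():
        for param in layer.parameters():
            param.copy_(torch.randn_like(param))
    # Values changed, but the weight format must be unchanged.
    assert capture_metadata(layer) == original_metadata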