[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)

Robert Shaw
2026-01-07 19:42:33 -05:00
committed by GitHub
parent ffc0a2798b
commit 5dcd7ef1f2
38 changed files with 1439 additions and 1528 deletions
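For orientation before the diff: with Fp8MoeBackend removed from the fp8 module's exports, the reloading test below now disables DeepGEMM through an environment variable and selects marlin by setting use_marlin directly on the quant method instead of assigning fp8_backend. A minimal sketch of that test-side pattern, assuming pytest and vLLM are available; build_fp8_moe_layer_and_method is a hypothetical stand-in for the layer construction shown in the diff:

import pytest


@pytest.mark.parametrize("use_marlin", [False])  # marlin path skipped, as in the diff
def test_fp8_reloading_sketch(use_marlin, monkeypatch):
    # DeepGEMM rejects the small test shapes, so force it off for this test run.
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")

    # Hypothetical helper standing in for the FusedMoE / Fp8MoEMethod setup
    # that the diff builds inline.
    layer, method = build_fp8_moe_layer_and_method()

    # After the refactor there is no fp8_backend attribute to assign;
    # the marlin kernel choice is toggled per-method via use_marlin.
    method.use_marlin = use_marlin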


@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
     Fp8Config,
     Fp8KVCacheMethod,
     Fp8LinearMethod,
-    Fp8MoeBackend,
     Fp8MoEMethod,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -278,8 +277,18 @@ def test_scaled_fp8_quant(dtype) -> None:
 # this is the case for marlin as well as per-tensor Fp8MoEMethod
 @pytest.mark.parametrize("use_marlin", [False])  # skip True
 def test_fp8_reloading(
-    method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
+    method_cls,
+    is_checkpoint_fp8_serialized,
+    weight_block_size,
+    use_marlin,
+    dist_init,
+    monkeypatch,
 ):
+    # NOTE(rob): this test fails when using DeepGEMM because the
+    # shapes are invalid. Previously the test was passing because
+    # we set fp8_backend to None, which sidestepped the issue.
+    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")
     if is_checkpoint_fp8_serialized is False:
         pytest.skip("FP8 weight reloading does not support online quantization")
@@ -307,6 +316,7 @@ def test_fp8_reloading(
             params_dtype=torch.bfloat16,
             weight_loader=default_weight_loader,
         )
+        method.use_marlin = use_marlin
     else:
         layer = FusedMoE(
@@ -325,11 +335,6 @@ def test_fp8_reloading(
             weight_loader=default_weight_loader,
         )
-    # Fp8LinearMethod uses use_marlin
-    # Fp8MoEMethod uses fp8_backend
-    method.use_marlin = use_marlin
-    method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
     # capture weights format during loading
     original_metadata = [
         (name, param.shape, getattr(param, "weight_loader", default_weight_loader))