[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)

Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
Maral
2026-04-09 08:50:39 +08:00
committed by GitHub
parent 8332078cfd
commit 2e9034c998
35 changed files with 1710 additions and 904 deletions

View File

@@ -16,6 +16,9 @@ from compressed_tensors.quantization import (
)
from tests.models.utils import check_logprobs_close
from vllm.model_executor.kernels.linear import (
Fp8BlockScaledMMLinearKernel,
)
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig,
@@ -29,7 +32,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsWNA16,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
@@ -473,16 +475,14 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert isinstance(
qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp
)
assert isinstance(qkv_proj.scheme.fp8_linear, Fp8BlockScaledMMLinearKernel)
assert qkv_proj.weight.dtype is fp8_dtype
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight.shape) == 2
assert len(qkv_proj.weight_scale.shape) == 2
input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
input_quant_op = qkv_proj.scheme.fp8_linear.quant_fp8
assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method in (
input_quant_op.forward_cuda,

View File

@@ -13,6 +13,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.config.model import ModelConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
@@ -406,6 +407,8 @@ def test_fp8_reloading(
"If this is your use case, consider using a restore function like #26327"
)
# Set model config as model_config.dtype is required in Fp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with torch.device("cuda:0"):
config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,

View File

@@ -12,6 +12,7 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm.config.model import ModelConfig
@pytest.fixture(scope="function", autouse=True)
@@ -46,7 +47,7 @@ def _snapshot_download_or_skip(model_id: str) -> str:
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
# TODO: provide a small publicly available test checkpoint
model_path = (
@@ -61,6 +62,8 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
"This test requires a local ModelOpt FP8 checkpoint."
)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
@@ -120,11 +123,13 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_pc_pt_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
model_path = _snapshot_download_or_skip(model_id)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
@@ -181,11 +186,13 @@ def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_pb_wo_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
model_path = _snapshot_download_or_skip(model_id)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):