[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm
2026-01-20 14:48:20 +08:00
committed by GitHub
parent e9c83cdc51
commit 148117ea2e
30 changed files with 1467 additions and 1038 deletions

View File

@@ -11,13 +11,13 @@ from abc import ABC
import pytest
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
Int8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel,
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel,
@@ -33,36 +33,38 @@ def test_is_supported_is_abstract():
def test_cpu_kernel_implements_is_supported():
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
"CPUScaledMMLinearKernel missing is_supported() method"
"""Test that CPUInt8ScaledMMLinearKernel implements is_supported() method."""
assert hasattr(CPUInt8ScaledMMLinearKernel, "is_supported"), (
"CPUInt8ScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
CPUScaledMMLinearKernel.is_supported
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
assert inspect.ismethod(
CPUInt8ScaledMMLinearKernel.is_supported
) or inspect.isfunction(CPUInt8ScaledMMLinearKernel.is_supported), (
"CPUInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
result, reason = CPUScaledMMLinearKernel.is_supported()
result, reason = CPUInt8ScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"
def test_aiter_kernel_implements_is_supported():
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
"AiterScaledMMLinearKernel missing is_supported() method"
"""Test that AiterInt8ScaledMMLinearKernel implements is_supported() method."""
assert hasattr(AiterInt8ScaledMMLinearKernel, "is_supported"), (
"AiterInt8ScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(
AiterScaledMMLinearKernel.is_supported
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
AiterInt8ScaledMMLinearKernel.is_supported
) or inspect.isfunction(AiterInt8ScaledMMLinearKernel.is_supported), (
"AiterInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
# (will return False on CPU, which is expected)
result, reason = AiterScaledMMLinearKernel.is_supported()
result, reason = AiterInt8ScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"
# On CPU, it should return False with a reason about requiring ROCm
@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported():
def test_cpu_kernel_accepts_all_configs():
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
"""Test that CPUInt8ScaledMMLinearKernel accepts all config combinations."""
configs = [
ScaledMMLinearLayerConfig(
Int8ScaledMMLinearLayerConfig(
is_channelwise=False,
is_static_input_scheme=True,
input_symmetric=True,
),
ScaledMMLinearLayerConfig(
Int8ScaledMMLinearLayerConfig(
is_channelwise=True,
is_static_input_scheme=False,
input_symmetric=False,
@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs():
]
for config in configs:
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
can_impl, reason = CPUInt8ScaledMMLinearKernel.can_implement(config)
assert can_impl, (
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
)