[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -11,13 +11,13 @@ from abc import ABC
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel,
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel,
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel,
|
||||
@@ -33,36 +33,38 @@ def test_is_supported_is_abstract():
|
||||
|
||||
|
||||
def test_cpu_kernel_implements_is_supported():
|
||||
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
|
||||
"CPUScaledMMLinearKernel missing is_supported() method"
|
||||
"""Test that CPUInt8ScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(CPUInt8ScaledMMLinearKernel, "is_supported"), (
|
||||
"CPUInt8ScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
|
||||
CPUScaledMMLinearKernel.is_supported
|
||||
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
assert inspect.ismethod(
|
||||
CPUInt8ScaledMMLinearKernel.is_supported
|
||||
) or inspect.isfunction(CPUInt8ScaledMMLinearKernel.is_supported), (
|
||||
"CPUInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
)
|
||||
# Verify it can be called as a classmethod
|
||||
result, reason = CPUScaledMMLinearKernel.is_supported()
|
||||
result, reason = CPUInt8ScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
|
||||
|
||||
def test_aiter_kernel_implements_is_supported():
|
||||
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
|
||||
"AiterScaledMMLinearKernel missing is_supported() method"
|
||||
"""Test that AiterInt8ScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(AiterInt8ScaledMMLinearKernel, "is_supported"), (
|
||||
"AiterInt8ScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(
|
||||
AiterScaledMMLinearKernel.is_supported
|
||||
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
|
||||
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
AiterInt8ScaledMMLinearKernel.is_supported
|
||||
) or inspect.isfunction(AiterInt8ScaledMMLinearKernel.is_supported), (
|
||||
"AiterInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
)
|
||||
# Verify it can be called as a classmethod
|
||||
# (will return False on CPU, which is expected)
|
||||
result, reason = AiterScaledMMLinearKernel.is_supported()
|
||||
result, reason = AiterInt8ScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
# On CPU, it should return False with a reason about requiring ROCm
|
||||
@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported():
|
||||
|
||||
|
||||
def test_cpu_kernel_accepts_all_configs():
|
||||
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
|
||||
"""Test that CPUInt8ScaledMMLinearKernel accepts all config combinations."""
|
||||
configs = [
|
||||
ScaledMMLinearLayerConfig(
|
||||
Int8ScaledMMLinearLayerConfig(
|
||||
is_channelwise=False,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=True,
|
||||
),
|
||||
ScaledMMLinearLayerConfig(
|
||||
Int8ScaledMMLinearLayerConfig(
|
||||
is_channelwise=True,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=False,
|
||||
@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs():
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
|
||||
can_impl, reason = CPUInt8ScaledMMLinearKernel.can_implement(config)
|
||||
assert can_impl, (
|
||||
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
|
||||
f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user