diff --git a/tests/kernels/quantization/test_scaled_mm_kernel_selection.py b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py
index 1ac663ff6..bedebdb59 100644
--- a/tests/kernels/quantization/test_scaled_mm_kernel_selection.py
+++ b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py
@@ -7,15 +7,21 @@ Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
 
 import inspect
 from abc import ABC
+from unittest.mock import patch
 
 import pytest
+import torch
 
 from vllm.model_executor.kernels.linear import (
     AiterInt8ScaledMMLinearKernel,
     CPUInt8ScaledMMLinearKernel,
+    Int8ScaledMMLinearKernel,
     Int8ScaledMMLinearLayerConfig,
     ScaledMMLinearKernel,
+    init_int8_linear_kernel,
+    register_linear_kernel,
 )
+from vllm.platforms import PlatformEnum
 
 pytestmark = pytest.mark.cpu_test
 
@@ -85,3 +91,46 @@ def test_cpu_kernel_accepts_all_configs():
         assert can_impl, (
             f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
         )
+
+
+class OOTInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
+    """Stub kernel for the OOT platform; accepts every capability and config."""
+
+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        pass
+
+
+# Swap in a fresh OOT entry so the registration does not leak into other tests.
+@patch.dict(
+    "vllm.model_executor.kernels.linear._POSSIBLE_INT8_KERNELS",
+    {PlatformEnum.OOT: []},
+)
+@patch("vllm.model_executor.kernels.linear.current_platform")
+def test_register_oot_linear_kernel(platform_mock):
+    """Test that the linear kernel registration works correctly."""
+    platform_mock._enum = PlatformEnum.OOT
+    register_linear_kernel(OOTInt8ScaledMMLinearKernel, PlatformEnum.OOT, "int8")
+
+    kernel = init_int8_linear_kernel(True, True, True, "module")
+
+    assert isinstance(kernel, OOTInt8ScaledMMLinearKernel), (
+        "init_int8_linear_kernel should return an instance of the registered kernel"
+    )
diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py
index 282208502..cfef32056 100644
--- a/vllm/model_executor/kernels/linear/__init__.py
+++ b/vllm/model_executor/kernels/linear/__init__.py
@@ -367,10 +367,46 @@ def choose_mp_linear_kernel(
     )
 
 
+def register_linear_kernel(
+    kernel_class: type,
+    platform: PlatformEnum,
+    kernel_type: str = "mp",
+) -> None:
+    """
+    Register a new linear kernel class to be considered in kernel selection.
+
+    The class is appended to the candidate list kept for ``platform``; the
+    list is created on first use. Registering a class that is already listed
+    for that platform is a no-op.
+
+    Args:
+        kernel_class (type): The kernel class to register.
+        platform (PlatformEnum): The platform for which this kernel is applicable.
+        kernel_type (str): The type of the kernel, either "mp", "int8", or "fp8".
+            Defaults to "mp".
+
+    Raises:
+        ValueError: If the kernel_type is not recognized.
+    """
+    registries: dict[str, dict[PlatformEnum, list[type]]] = {
+        "mp": _POSSIBLE_KERNELS,
+        "int8": _POSSIBLE_INT8_KERNELS,
+        "fp8": _POSSIBLE_FP8_KERNELS,
+    }
+    if kernel_type not in registries:
+        raise ValueError(f"Unrecognized kernel type: {kernel_type}")
+
+    kernels = registries[kernel_type].setdefault(platform, [])
+    # Skip duplicates so registering the same class twice does not grow the list.
+    if kernel_class not in kernels:
+        kernels.append(kernel_class)
+
+
 __all__ = [
     "init_fp8_linear_kernel",
     "init_int8_linear_kernel",
     "choose_mp_linear_kernel",
+    "register_linear_kernel",
     "FP8ScaledMMLinearKernel",
     "Int8ScaledMMLinearKernel",
     "ScaledMMLinearKernel",