[OOT] Add OOT support for linear kernel. (#37989)

Signed-off-by: menogrey <1299267905@qq.com>
This commit is contained in:
zhangyiming
2026-03-31 14:33:21 +08:00
committed by GitHub
parent 6cc7abdc66
commit 1ac6694297
2 changed files with 76 additions and 0 deletions

View File

@@ -7,15 +7,21 @@ Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
import inspect
from abc import ABC
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.kernels.linear import (
AiterInt8ScaledMMLinearKernel,
CPUInt8ScaledMMLinearKernel,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
init_int8_linear_kernel,
register_linear_kernel,
)
from vllm.platforms import PlatformEnum
pytestmark = pytest.mark.cpu_test
@@ -85,3 +91,39 @@ def test_cpu_kernel_accepts_all_configs():
assert can_impl, (
f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
)
class OOTInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    """Minimal out-of-tree (OOT) int8 scaled-mm kernel stub.

    Every hook is a no-op: the class exists only so the registration test
    can verify that a kernel registered for `PlatformEnum.OOT` is the one
    selected by the kernel-initialization path.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # Always report support so platform filtering never rejects the stub.
        return (True, None)

    @classmethod
    def can_implement(
        cls, c: Int8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        # Accept every layer configuration unconditionally.
        return (True, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # No weight post-processing is needed for the stub.
        return None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Never invoked by the registration test; returns None like the
        # original stub did via an implicit `pass`.
        return None
@patch("vllm.model_executor.kernels.linear.current_platform")
def test_register_oot_linear_kernel(platform_mock):
    """Test that the linear kernel registration works correctly."""
    # Make the patched platform report itself as out-of-tree so the
    # registry consults the OOT entries.
    platform_mock._enum = PlatformEnum.OOT
    # Register the stub kernel for the OOT platform under the "int8" kind.
    register_linear_kernel(OOTInt8ScaledMMLinearKernel, PlatformEnum.OOT, "int8")
    # Selection must now hand back an instance of the registered class.
    selected = init_int8_linear_kernel(True, True, True, "module")
    assert isinstance(selected, OOTInt8ScaledMMLinearKernel), (
        "init_int8_linear_kernel should return an instance of the registered kernel"
    )