[OOT] Add OOT support for linear kernel. (#37989)
Signed-off-by: menogrey <1299267905@qq.com>
This commit is contained in:
@@ -7,15 +7,21 @@ Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
init_int8_linear_kernel,
|
||||
register_linear_kernel,
|
||||
)
|
||||
from vllm.platforms import PlatformEnum
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@@ -85,3 +91,39 @@ def test_cpu_kernel_accepts_all_configs():
|
||||
assert can_impl, (
|
||||
f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
|
||||
)
|
||||
|
||||
|
||||
class OOTInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
@patch("vllm.model_executor.kernels.linear.current_platform")
|
||||
def test_register_oot_linear_kernel(platform_mock):
|
||||
"""Test that the linear kernel registration works correctly."""
|
||||
platform_mock._enum = PlatformEnum.OOT
|
||||
register_linear_kernel(OOTInt8ScaledMMLinearKernel, PlatformEnum.OOT, "int8")
|
||||
|
||||
kernel = init_int8_linear_kernel(True, True, True, "module")
|
||||
|
||||
assert isinstance(kernel, OOTInt8ScaledMMLinearKernel), (
|
||||
"init_int8_linear_kernel should return an instance of the registered kernel"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user