[AMD][Quantization] Add TritonScaledMMLinearKernel since int8 is broken for AMD (#12282)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
rasmith
2025-01-22 18:10:37 -06:00
committed by GitHub
parent aea94362c9
commit 68c4421b6d
3 changed files with 58 additions and 5 deletions

View File

@@ -5,8 +5,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
# TritonScaledMMLinear)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel)
from vllm.platforms import PlatformEnum, current_platform
@@ -15,9 +15,7 @@ from vllm.platforms import PlatformEnum, current_platform
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
# TODO(rob): Create TritonScaledMMLinear kernel. ROCM will
# incorrectly attempt to run AZP models if prompted to.
PlatformEnum.ROCM: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}

View File

@@ -0,0 +1,38 @@
from typing import Optional, Tuple
import torch
from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if current_platform.is_cpu():
return (
False,
"TritonScaledMMLinearKernel requires Triton which is not " +
"currently supported on CPU.")
if not c.input_symmetric:
return (False,
"TritonScaledMMLinearKernel only supports symmetric " +
"quantization.")
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return super().apply_weights(layer, x, bias)