[AMD][Quantization] Add TritonScaledMMLinearKernel since int8 is broken for AMD (#12282)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
@@ -39,6 +39,23 @@ def get_8bit_types():
|
|||||||
return types
|
return types
|
||||||
|
|
||||||
|
|
||||||
|
# This test is to check regressions for int8 support on ROCm.
|
||||||
|
@pytest.mark.parametrize("model_path", [
|
||||||
|
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [10])
|
||||||
|
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||||
|
reason="Should only run on ROCm")
|
||||||
|
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
|
||||||
|
max_tokens, num_logprobs):
|
||||||
|
dtype = "bfloat16"
|
||||||
|
|
||||||
|
with vllm_runner(model_path, dtype=dtype) as vllm_model:
|
||||||
|
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
|
||||||
|
num_logprobs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("M", [1, 33, 64, 512])
|
@pytest.mark.parametrize("M", [1, 33, 64, 512])
|
||||||
@pytest.mark.parametrize("N", [256, 971, 20486])
|
@pytest.mark.parametrize("N", [256, 971, 20486])
|
||||||
@pytest.mark.parametrize("K", [128, 496, 1024])
|
@pytest.mark.parametrize("K", [128, 496, 1024])
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
|||||||
CutlassScaledMMLinearKernel)
|
CutlassScaledMMLinearKernel)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
|
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
|
||||||
# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||||
# TritonScaledMMLinear)
|
TritonScaledMMLinearKernel)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||||
XLAScaledMMLinearKernel)
|
XLAScaledMMLinearKernel)
|
||||||
from vllm.platforms import PlatformEnum, current_platform
|
from vllm.platforms import PlatformEnum, current_platform
|
||||||
@@ -15,9 +15,7 @@ from vllm.platforms import PlatformEnum, current_platform
|
|||||||
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
|
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
|
||||||
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
|
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
|
||||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||||
# TODO(rob): Create TritonScaledMMLinear kernel. ROCM will
|
PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
|
||||||
# incorrectly attempt to run AZP models if prompted to.
|
|
||||||
PlatformEnum.ROCM: [CutlassScaledMMLinearKernel],
|
|
||||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .cutlass import CutlassScaledMMLinearKernel
|
||||||
|
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 75
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_implement(
|
||||||
|
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||||
|
if current_platform.is_cpu():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"TritonScaledMMLinearKernel requires Triton which is not " +
|
||||||
|
"currently supported on CPU.")
|
||||||
|
if not c.input_symmetric:
|
||||||
|
return (False,
|
||||||
|
"TritonScaledMMLinearKernel only supports symmetric " +
|
||||||
|
"quantization.")
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
super().process_weights_after_loading(layer)
|
||||||
|
|
||||||
|
def apply_weights(self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
return super().apply_weights(layer, x, bias)
|
||||||
Reference in New Issue
Block a user