[ROCm] Enable Triton ScaledMM fallback + kernel selection fix (#26668)
Signed-off-by: Shivam <shivampr.dev@gmail.com> Signed-off-by: Shivam <shivamprasad91@gmail.com>
This commit is contained in:
@@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig:
|
||||
class ScaledMMLinearKernel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC):
|
||||
azp_adj_param_name: str,
|
||||
) -> None:
|
||||
assert self.can_implement(c)
|
||||
assert self.is_supported()
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
|
||||
type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} disabled by environment variable"
|
||||
)
|
||||
failure_reasons.append(f"{kernel.__name__}: disabled by env var")
|
||||
continue
|
||||
|
||||
# If the current platform uses compute_capability,
|
||||
# make sure the kernel supports the compute cability.
|
||||
if compute_capability is not None:
|
||||
kernel_min_capability = kernel.get_min_capability()
|
||||
if (
|
||||
kernel_min_capability is not None
|
||||
and kernel_min_capability > compute_capability
|
||||
):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel_min_capability}, current compute capability "
|
||||
f"is {compute_capability}"
|
||||
)
|
||||
continue
|
||||
is_supported, reason = kernel.is_supported(compute_capability)
|
||||
if not is_supported:
|
||||
failure_reasons.append(f"{kernel.__name__}: {reason}")
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} cannot implement due to: {failure_reason}"
|
||||
)
|
||||
can_implement, reason = kernel.can_implement(config)
|
||||
if not can_implement:
|
||||
failure_reasons.append(f"{kernel.__name__}: {reason}")
|
||||
continue
|
||||
|
||||
return kernel
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "
|
||||
|
||||
@@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not "
|
||||
+ "currently supported on non-ROCm platform.",
|
||||
)
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc.major * 10 + _cc.minor
|
||||
if compute_capability is not None and compute_capability < 90:
|
||||
return False, f"requires capability 90, got {compute_capability}"
|
||||
|
||||
try:
|
||||
import aiter # noqa: F401 # deliberately attempt to import aiter
|
||||
@@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not "
|
||||
+ "installed on ROCm.",
|
||||
)
|
||||
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
|
||||
if not (rocm_aiter_ops.is_linear_enabled()):
|
||||
|
||||
if not rocm_aiter_ops.is_linear_enabled():
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel is disabled. "
|
||||
@@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not c.input_symmetric:
|
||||
return (
|
||||
False,
|
||||
|
||||
@@ -19,14 +19,15 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
|
||||
|
||||
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Requires CPU."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "CPUScaledMM requires running on CPU."
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
@@ -16,14 +16,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
|
||||
|
||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Requires CUDA."
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc.major * 10 + _cc.minor
|
||||
if compute_capability is not None and compute_capability < 75:
|
||||
return False, f"requires capability 75, got {compute_capability}"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CutlassScaledMM requires running on CUDA."
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
@@ -4,34 +4,53 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
|
||||
triton_scaled_mm,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if current_platform.is_cuda_alike():
|
||||
return True, None
|
||||
return False, "Requires ROCm or CUDA."
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if current_platform.is_cpu():
|
||||
return (
|
||||
False,
|
||||
"TritonScaledMMLinearKernel requires Triton which is not "
|
||||
+ "currently supported on CPU.",
|
||||
)
|
||||
if not c.input_symmetric:
|
||||
return (
|
||||
False,
|
||||
"TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
|
||||
)
|
||||
return False, "Only symmetric input is supported."
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@@ -39,4 +58,14 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return super().apply_weights(layer, x, bias)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=True
|
||||
)
|
||||
|
||||
assert x_zp is None, "Triton kernel only supports symmetric quantization"
|
||||
|
||||
return triton_scaled_mm(
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
|
||||
@@ -17,11 +17,12 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
|
||||
|
||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"TPU platform does have a concept of compute capability, "
|
||||
"this method should not be called."
|
||||
)
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_tpu():
|
||||
return False, "Requires TPU."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
|
||||
Reference in New Issue
Block a user