[Refactor] [1/N] Reorganize kernel abstraction directory (#34055)
Signed-off-by: BadrBasowid <badr.basowid@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -26,24 +26,16 @@ from vllm.config import (
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
RowWiseTorchFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
|
||||
@@ -26,22 +26,14 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
|
||||
@@ -10,16 +10,10 @@ from abc import ABC
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
|
||||
@@ -42,11 +42,9 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
|
||||
@@ -1,34 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
This module re-exports linear kernel implementations to provide a
|
||||
stable import interface during an ongoing reorganization. Upcoming
|
||||
PRs will remove the scaled_mm and mixed_precision subdirectories
|
||||
and reorganize kernels by provider (aiter, cutlass, flashinfer, etc.)
|
||||
rather than by precision type. By centralizing exports here, we
|
||||
minimize the need to update imports across other modules when the
|
||||
internal structure changes. If you are adding a new kernel selector
|
||||
or kernel implementation, add it to this __init__.py to maintain
|
||||
import stability.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
from vllm.model_executor.kernels.linear.mixed_precision import (
|
||||
MPLinearKernel,
|
||||
MPLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
|
||||
AllSparkLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
CutlassInt8ScaledMMLinearKernel,
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
|
||||
ConchLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
|
||||
CPUWNA16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
RowWiseTorchFP8ScaledMMLinearKernel,
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
|
||||
CutlassW4A8LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
|
||||
Dynamic4bitLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
|
||||
ExllamaLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
|
||||
MacheteLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
|
||||
MarlinLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
@@ -36,10 +59,31 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
CutlassInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
RowWiseTorchFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||
TritonInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
|
||||
XPUFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
@@ -80,6 +124,29 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
||||
],
|
||||
}
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [
|
||||
CutlassW4A8LinearKernel,
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
PlatformEnum.ROCM: [
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
PlatformEnum.XPU: [
|
||||
XPUwNa16LinearKernel,
|
||||
],
|
||||
PlatformEnum.CPU: [
|
||||
Dynamic4bitLinearKernel,
|
||||
CPUWNA16LinearKernel,
|
||||
],
|
||||
}
|
||||
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
||||
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
@@ -234,3 +301,97 @@ def init_int8_linear_kernel(
|
||||
"azp_adj",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
config: MPLinearLayerConfig, compute_capability: int | None = None
|
||||
) -> type[MPLinearKernel]:
|
||||
"""
|
||||
Choose an MPLinearKernel that can implement the given config for the given
|
||||
compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (MPLinearLayerConfig): Description of the linear layer to be
|
||||
implemented.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get
|
||||
the compute capability. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[MPLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
if compute_capability is None:
|
||||
if current_platform is None:
|
||||
raise ValueError("Cannot determine compute capability")
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} disabled by environment variable"
|
||||
)
|
||||
continue
|
||||
if (
|
||||
compute_capability is not None
|
||||
and kernel.get_min_capability() > compute_capability
|
||||
):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel.get_min_capability()}, current compute "
|
||||
f" capability is {compute_capability}"
|
||||
)
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} cannot implement due to: {failure_reason}"
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "
|
||||
"WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"init_fp8_linear_kernel",
|
||||
"init_int8_linear_kernel",
|
||||
"choose_mp_linear_kernel",
|
||||
"FP8ScaledMMLinearKernel",
|
||||
"Int8ScaledMMLinearKernel",
|
||||
"ScaledMMLinearKernel",
|
||||
"FP8ScaledMMLinearLayerConfig",
|
||||
"Int8ScaledMMLinearLayerConfig",
|
||||
"ScaledMMLinearLayerConfig",
|
||||
"AiterInt8ScaledMMLinearKernel",
|
||||
"CPUInt8ScaledMMLinearKernel",
|
||||
"CutlassFP8ScaledMMLinearKernel",
|
||||
"CutlassInt8ScaledMMLinearKernel",
|
||||
"FlashInferFP8ScaledMMLinearKernel",
|
||||
"ChannelWiseTorchFP8ScaledMMLinearKernel",
|
||||
"PerTensorTorchFP8ScaledMMLinearKernel",
|
||||
"RowWiseTorchFP8ScaledMMLinearKernel",
|
||||
"ROCmFP8ScaledMMLinearKernel",
|
||||
"TritonInt8ScaledMMLinearKernel",
|
||||
"MPLinearKernel",
|
||||
"MPLinearLayerConfig",
|
||||
"AllSparkLinearKernel",
|
||||
"ConchLinearKernel",
|
||||
"CPUWNA16LinearKernel",
|
||||
"CutlassW4A8LinearKernel",
|
||||
"Dynamic4bitLinearKernel",
|
||||
"ExllamaLinearKernel",
|
||||
"MacheteLinearKernel",
|
||||
"MarlinLinearKernel",
|
||||
"XPUwNa16LinearKernel",
|
||||
]
|
||||
@@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
|
||||
AllSparkLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
|
||||
ConchLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
|
||||
CPUWNA16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
|
||||
CutlassW4A8LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
|
||||
Dynamic4bitLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
|
||||
ExllamaLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
|
||||
MacheteLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
|
||||
MarlinLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
|
||||
MPLinearKernel,
|
||||
MPLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MPLinearKernel",
|
||||
"MPLinearLayerConfig",
|
||||
"AllSparkLinearKernel",
|
||||
"ConchLinearKernel",
|
||||
"CPUWNA16LinearKernel",
|
||||
"CutlassW4A8LinearKernel",
|
||||
"Dynamic4bitLinearKernel",
|
||||
"ExllamaLinearKernel",
|
||||
"MacheteLinearKernel",
|
||||
"MarlinLinearKernel",
|
||||
"XPUwNa16LinearKernel",
|
||||
]
|
||||
54
vllm/model_executor/kernels/linear/scaled_mm/__init__.py
Normal file
54
vllm/model_executor/kernels/linear/scaled_mm/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
CutlassInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
RowWiseTorchFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||
TritonInt8ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FP8ScaledMMLinearKernel",
|
||||
"FP8ScaledMMLinearLayerConfig",
|
||||
"Int8ScaledMMLinearKernel",
|
||||
"Int8ScaledMMLinearLayerConfig",
|
||||
"ScaledMMLinearKernel",
|
||||
"ScaledMMLinearLayerConfig",
|
||||
"AiterInt8ScaledMMLinearKernel",
|
||||
"CPUInt8ScaledMMLinearKernel",
|
||||
"CutlassFP8ScaledMMLinearKernel",
|
||||
"CutlassInt8ScaledMMLinearKernel",
|
||||
"FlashInferFP8ScaledMMLinearKernel",
|
||||
"ChannelWiseTorchFP8ScaledMMLinearKernel",
|
||||
"PerTensorTorchFP8ScaledMMLinearKernel",
|
||||
"RowWiseTorchFP8ScaledMMLinearKernel",
|
||||
"ROCmFP8ScaledMMLinearKernel",
|
||||
"TritonInt8ScaledMMLinearKernel",
|
||||
]
|
||||
@@ -5,7 +5,7 @@ from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
from vllm.model_executor.kernels.linear import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
@@ -7,13 +7,13 @@ import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
|
||||
@@ -6,13 +6,13 @@ from collections.abc import Callable
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
|
||||
@@ -9,12 +9,12 @@ from torch.nn import Parameter
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
|
||||
@@ -7,12 +7,12 @@ import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
|
||||
@@ -7,15 +7,13 @@ import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MarlinLinearKernel,
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (
|
||||
MarlinLinearKernel,
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
|
||||
@@ -8,6 +8,9 @@ from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
@@ -18,9 +21,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
|
||||
@@ -13,6 +13,9 @@ from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@@ -46,9 +49,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
|
||||
@@ -10,6 +10,10 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -27,10 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_dynamic_override,
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||
AllSparkLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
|
||||
ConchLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cpu import ( # noqa: E501
|
||||
CPUWNA16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
|
||||
CutlassW4A8LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
|
||||
Dynamic4bitLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||
ExllamaLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
||||
MacheteLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
|
||||
MarlinLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
|
||||
MPLinearKernel,
|
||||
MPLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [
|
||||
CutlassW4A8LinearKernel,
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
PlatformEnum.ROCM: [
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
PlatformEnum.XPU: [
|
||||
XPUwNa16LinearKernel,
|
||||
],
|
||||
PlatformEnum.CPU: [
|
||||
Dynamic4bitLinearKernel,
|
||||
CPUWNA16LinearKernel,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
config: MPLinearLayerConfig, compute_capability: int | None = None
|
||||
) -> type[MPLinearKernel]:
|
||||
"""
|
||||
Choose an MPLinearKernel that can implement the given config for the given
|
||||
compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (MPLinearLayerConfig): Description of the linear layer to be
|
||||
implemented.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get
|
||||
the compute capability. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[MPLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
if compute_capability is None:
|
||||
if current_platform is None:
|
||||
raise ValueError("Cannot determine compute capability")
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} disabled by environment variable"
|
||||
)
|
||||
continue
|
||||
if (
|
||||
compute_capability is not None
|
||||
and kernel.get_min_capability() > compute_capability
|
||||
):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel.get_min_capability()}, current compute "
|
||||
f" capability is {compute_capability}"
|
||||
)
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} cannot implement due to: {failure_reason}"
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "
|
||||
"WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
|
||||
)
|
||||
@@ -9,6 +9,9 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -45,9 +48,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
flashinfer_trtllm_fp4_moe,
|
||||
|
||||
@@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -17,9 +20,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections.abc import Callable
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
|
||||
Reference in New Issue
Block a user