[FP8]add FP8 WoQ kernel abstraction. (#32929)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
This commit is contained in:
@@ -72,6 +72,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
|
||||
MarlinFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
@@ -104,6 +107,7 @@ _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]]
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [
|
||||
MarlinFP8ScaledMMLinearKernel,
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
|
||||
@@ -14,6 +14,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
|
||||
MarlinFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
@@ -46,6 +49,7 @@ __all__ = [
|
||||
"CutlassFP8ScaledMMLinearKernel",
|
||||
"CutlassInt8ScaledMMLinearKernel",
|
||||
"FlashInferFP8ScaledMMLinearKernel",
|
||||
"MarlinFP8ScaledMMLinearKernel",
|
||||
"ChannelWiseTorchFP8ScaledMMLinearKernel",
|
||||
"PerTensorTorchFP8ScaledMMLinearKernel",
|
||||
"RowWiseTorchFP8ScaledMMLinearKernel",
|
||||
|
||||
120
vllm/model_executor/kernels/linear/scaled_mm/marlin.py
Normal file
120
vllm/model_executor/kernels/linear/scaled_mm/marlin.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
process_fp8_weight_block_strategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
is_fp8_marlin_supported,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Static128BlockSym,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
"""
|
||||
FP8 Marlin kernel for GPUs that lack FP8 hardware support.
|
||||
Leverages the Marlin kernel for fast weight-only FP8 quantization.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "requires CUDA."
|
||||
# Check if platform supports FP8 Marlin
|
||||
if not is_fp8_marlin_supported():
|
||||
return False, "FP8 Marlin requires compute capability 7.5 or higher"
|
||||
if vllm_is_batch_invariant():
|
||||
return False, "FP8 Marlin not supported for batch invariant execution."
|
||||
if (
|
||||
compute_capability is not None
|
||||
and compute_capability >= 89
|
||||
and not envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"To apply FP8 Marlin on high-capability GPUs, please set "
|
||||
"VLLM_TEST_FORCE_FP8_MARLIN=1",
|
||||
)
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
super().__init__(c, layer_param_names)
|
||||
self.marlin_input_dtype = None
|
||||
self.block_quant = self.config.weight_quant_key in {kFp8Static128BlockSym}
|
||||
self.size_k_first = not self.block_quant
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.block_quant:
|
||||
weight, weight_scale_inv = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale_inv
|
||||
)
|
||||
# Update layer with new values
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||
else:
|
||||
weight = layer.weight.t()
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
layer.input_scale = None
|
||||
prepare_fp8_layer_for_marlin(
|
||||
layer, self.size_k_first, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
del layer.input_scale
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.block_quant:
|
||||
weight_scale = layer.weight_scale_inv
|
||||
else:
|
||||
weight_scale = layer.weight_scale
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
pass
|
||||
@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -177,15 +176,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@@ -7,7 +7,6 @@ import torch
|
||||
from torch.nn import Module
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
@@ -16,6 +15,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@@ -61,10 +61,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
@@ -280,15 +276,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.marlin_input_dtype = None
|
||||
self.use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
self.use_marlin = False
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
@@ -297,16 +284,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if self.act_q_static:
|
||||
activation_quant_key = kFp8StaticTensorSym
|
||||
@@ -315,12 +292,28 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
else:
|
||||
activation_quant_key = kFp8DynamicTensorSym
|
||||
|
||||
if self.block_quant:
|
||||
weight_quant_key = kFp8Static128BlockSym
|
||||
else:
|
||||
weight_quant_key = kFp8StaticTensorSym
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
|
||||
|
||||
if self.block_quant and not self.use_marlin:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -387,12 +380,18 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
size_k_first = True
|
||||
if self.use_marlin:
|
||||
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
|
||||
# AttributeError if backend selection changes.
|
||||
if hasattr(self.fp8_linear, "marlin_input_dtype"):
|
||||
self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
return
|
||||
|
||||
input_scale = None
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
size_k_first = False
|
||||
|
||||
weight, weight_scale_inv = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale_inv
|
||||
@@ -411,7 +410,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||
weight,
|
||||
weight_scale,
|
||||
@@ -432,14 +430,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(
|
||||
layer, size_k_first, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
return
|
||||
|
||||
if self.block_quant:
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
@@ -486,21 +476,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
|
||||
if self.use_marlin:
|
||||
if self.block_quant:
|
||||
weight_scale = layer.weight_scale_inv
|
||||
else:
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
@@ -623,18 +599,20 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
|
||||
layer.input_scale = None
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
weight = qweight.t()
|
||||
|
||||
# Update layer with new values.
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight", qweight.data)
|
||||
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||
|
||||
if self.use_marlin:
|
||||
size_k_first = True
|
||||
prepare_fp8_layer_for_marlin(
|
||||
layer, size_k_first, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
|
||||
# AttributeError if backend selection changes.
|
||||
if hasattr(self.fp8_linear, "marlin_input_dtype"):
|
||||
self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
else:
|
||||
weight = qweight.t()
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
|
||||
# Prevent duplicate processing (e.g., during weight reload)
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
Reference in New Issue
Block a user