[FP8] Add FP8 WoQ kernel abstraction (#32929)

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Authored by Kunshang Ji on 2026-03-23 17:47:47 +08:00; committed by GitHub.
Parent: 35141a7eed
Commit: 27d5ee3e6f
5 changed files with 177 additions and 83 deletions

View File

@@ -72,6 +72,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
MarlinFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
@@ -104,6 +107,7 @@ _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]]
# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
PlatformEnum.CUDA: [
MarlinFP8ScaledMMLinearKernel,
FlashInferFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
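
The kernel table above is consulted in priority order: a kernel is used only if both its platform-level is_supported() check and its per-layer can_implement() check pass. The following is a minimal, self-contained sketch of such a selection loop; the function name and failure handling are illustrative assumptions, not vLLM's actual selector.

def pick_first_usable_kernel(candidates, config, compute_capability=None):
    # Walk a priority-ordered candidate list (e.g. _POSSIBLE_FP8_KERNELS[platform])
    # and return the first kernel class whose checks both pass.
    failures = []
    for kernel_cls in candidates:
        supported, reason = kernel_cls.is_supported(compute_capability)
        if not supported:
            failures.append(f"{kernel_cls.__name__}: {reason}")
            continue
        implementable, reason = kernel_cls.can_implement(config)
        if not implementable:
            failures.append(f"{kernel_cls.__name__}: {reason}")
            continue
        return kernel_cls
    raise ValueError("no usable FP8 scaled-mm kernel: " + "; ".join(failures))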

View File

@@ -14,6 +14,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
MarlinFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
@@ -46,6 +49,7 @@ __all__ = [
"CutlassFP8ScaledMMLinearKernel",
"CutlassInt8ScaledMMLinearKernel",
"FlashInferFP8ScaledMMLinearKernel",
"MarlinFP8ScaledMMLinearKernel",
"ChannelWiseTorchFP8ScaledMMLinearKernel",
"PerTensorTorchFP8ScaledMMLinearKernel",
"RowWiseTorchFP8ScaledMMLinearKernel",

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import torch
import vllm.envs as envs
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_block_strategy,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
is_fp8_marlin_supported,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
"""
FP8 Marlin kernel for GPUs that lack FP8 hardware support.
Leverages the Marlin kernel for fast weight-only FP8 quantization.
"""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "requires CUDA."
# Check if platform supports FP8 Marlin
if not is_fp8_marlin_supported():
return False, "FP8 Marlin requires compute capability 7.5 or higher"
if vllm_is_batch_invariant():
return False, "FP8 Marlin not supported for batch invariant execution."
if (
compute_capability is not None
and compute_capability >= 89
and not envs.VLLM_TEST_FORCE_FP8_MARLIN
):
return (
False,
"To apply FP8 Marlin on high-capability GPUs, please set "
"VLLM_TEST_FORCE_FP8_MARLIN=1",
)
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
super().__init__(c, layer_param_names)
self.marlin_input_dtype = None
self.block_quant = self.config.weight_quant_key in {kFp8Static128BlockSym}
self.size_k_first = not self.block_quant
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.block_quant:
weight, weight_scale_inv = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale_inv
)
# Update layer with new values
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
else:
weight = layer.weight.t()
replace_parameter(layer, "weight", weight.data)
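        # Marlin runs weight-only FP8, so no activation (input) scale is kept.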
layer.input_scale = None
prepare_fp8_layer_for_marlin(
layer, self.size_k_first, input_dtype=self.marlin_input_dtype
)
del layer.input_scale
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.block_quant:
weight_scale = layer.weight_scale_inv
else:
weight_scale = layer.weight_scale
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
input_dtype=self.marlin_input_dtype,
bias=bias,
)
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
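        # apply_weights() above calls the Marlin GEMM directly, so this
        # generic scaled-mm hook is not used by this kernel.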
pass
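
The new kernel follows the same life cycle as the other ScaledMM kernels: class-level is_supported()/can_implement() gates, a one-time process_weights_after_loading() repack, and apply_weights() on the hot path. The standalone sketch below mirrors only that interface shape with a dummy kernel; it performs no FP8 quantization or Marlin repacking and does not import vLLM.

import torch

class DummyKernel:
    @classmethod
    def is_supported(cls, compute_capability=None):
        return True, None

    @classmethod
    def can_implement(cls, config):
        return True, None

    def __init__(self, config, layer_param_names):
        self.config = config
        self.layer_param_names = layer_param_names

    def process_weights_after_loading(self, layer):
        # Real kernels repack/transpose the weight exactly once here.
        layer.weight = torch.nn.Parameter(layer.weight.t().contiguous(),
                                          requires_grad=False)

    def apply_weights(self, layer, x, bias=None):
        # The weight was transposed above, so transpose back for F.linear.
        return torch.nn.functional.linear(x, layer.weight.t(), bias)

layer = torch.nn.Linear(16, 32, bias=False)
kernel = DummyKernel(config=None, layer_param_names=("weight",))
kernel.process_weights_after_loading(layer)
out = kernel.apply_weights(layer, torch.randn(4, 16))  # shape (4, 32)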

View File

@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -177,15 +176,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.quant_config.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@@ -7,7 +7,6 @@ import torch
from torch.nn import Module
from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
@@ -16,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@@ -61,10 +61,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
@@ -280,15 +276,6 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if current_platform.is_rocm() or current_platform.is_xpu():
self.use_marlin = False
if vllm_is_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
self.use_deep_gemm = is_deep_gemm_supported()
@@ -297,7 +284,28 @@ class Fp8LinearMethod(LinearMethodBase):
self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if self.act_q_static:
activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym
else:
activation_quant_key = kFp8DynamicTensorSym
if self.block_quant:
weight_quant_key = kFp8Static128BlockSym
else:
weight_quant_key = kFp8StaticTensorSym
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
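        # The kernel returned by the selector determines whether the Marlin
        # weight-only path is used; no separately maintained flag is needed.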
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
if self.block_quant and not self.use_marlin:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
@@ -306,21 +314,6 @@ class Fp8LinearMethod(LinearMethodBase):
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
# Use per-token quantization for better perf if dynamic and cutlass
if self.act_q_static:
activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym
else:
activation_quant_key = kFp8DynamicTensorSym
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights(
self,
@@ -387,12 +380,18 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None:
size_k_first = True
if self.use_marlin:
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
# AttributeError if backend selection changes.
if hasattr(self.fp8_linear, "marlin_input_dtype"):
self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
self.fp8_linear.process_weights_after_loading(layer)
return
input_scale = None
# TODO(rob): refactor block quant into separate class.
if self.block_quant:
assert not self.act_q_static
size_k_first = False
weight, weight_scale_inv = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale_inv
@@ -411,16 +410,15 @@ class Fp8LinearMethod(LinearMethodBase):
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if not self.use_marlin:
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
weight,
weight_scale,
layer.logical_widths,
getattr(layer, "input_scale", None),
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
weight,
weight_scale,
layer.logical_widths,
getattr(layer, "input_scale", None),
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
# Update layer with new values.
@@ -432,14 +430,6 @@ class Fp8LinearMethod(LinearMethodBase):
else:
layer.input_scale = None
if self.use_marlin:
prepare_fp8_layer_for_marlin(
layer, size_k_first, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.input_scale
return
if self.block_quant:
maybe_post_process_fp8_weight_block(layer)
@@ -486,21 +476,7 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
if self.block_quant:
weight_scale = layer.weight_scale_inv
else:
weight_scale = layer.weight_scale
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
input_dtype=self.marlin_input_dtype,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)
if self.block_quant:
assert self.weight_block_size is not None
@@ -623,18 +599,20 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
layer.input_scale = None
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# Update layer with new values.
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight", qweight.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
if self.use_marlin:
size_k_first = True
prepare_fp8_layer_for_marlin(
layer, size_k_first, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
# AttributeError if backend selection changes.
if hasattr(self.fp8_linear, "marlin_input_dtype"):
self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
self.fp8_linear.process_weights_after_loading(layer)
else:
weight = qweight.t()
replace_parameter(layer, "weight", weight.data)
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
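
For reference, the online path above quantizes each weight once with ops.scaled_fp8_quant and then either hands the layer to the Marlin kernel or transposes the weight for the scaled-mm kernels. The plain-PyTorch snippet below is a numerical reference for what FP8 weight-only quantization computes, assuming a single per-tensor scale; it is not the Marlin kernel, which keeps the weight packed and may use the 128-block scale scheme instead.

import torch

def quantize_weight_fp8(w: torch.Tensor):
    # One per-tensor scale mapping the weight's max magnitude to the FP8 range.
    scale = w.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    w_fp8 = (w.float() / scale).to(torch.float8_e4m3fn)
    return w_fp8, scale

def woq_linear_reference(x, w_fp8, scale, bias=None):
    # Weight-only: dequantize the weight, keep activations in their dtype.
    w = w_fp8.to(x.dtype) * scale.to(x.dtype)
    return torch.nn.functional.linear(x, w, bias)

w = torch.randn(32, 16, dtype=torch.bfloat16)   # [out_features, in_features]
x = torch.randn(4, 16, dtype=torch.bfloat16)
w_fp8, scale = quantize_weight_fp8(w)
err = (woq_linear_reference(x, w_fp8, scale) - x @ w.t()).abs().max()
print(f"max abs error vs. bf16 reference: {err.item():.4f}")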