[Kernel] fp4 marlin kernel (#17687)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
Jinzhen Lin
2025-05-11 10:58:49 +08:00
committed by GitHub
parent ca66a1674c
commit d74e5f37bc
21 changed files with 1216 additions and 331 deletions

View File

@@ -304,8 +304,10 @@ class HQQMarlinMethod(LinearMethodBase):
marlin_out = ops.gptq_marlin_gemm(
x,
None,
layer.marlin_qweight,
scales,
None,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,
@@ -315,7 +317,7 @@ class HQQMarlinMethod(LinearMethodBase):
self.output_size_per_partition,
self.input_size_per_partition,
True, # is_k_full
True, # has_zp
False, # use atomic add
True, # use 32-bit reduce
True, # use float zp
)

View File

@@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -24,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -196,7 +200,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 100
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -278,9 +282,15 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
if not self.cutlass_nvfp4_supported:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and above.")
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")
def create_weights(
self,
@@ -392,12 +402,29 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
if self.use_marlin:
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
del layer.weight_scale_swizzled
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.use_marlin:
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
output_dtype = x.dtype
# for input only the contracting dimension has a constraint.
@@ -434,6 +461,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -442,6 +479,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
raise ValueError("NVFP4 quantization was selected, "
" dynamic quantization is not supported.")
layer.num_experts = num_experts
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
@@ -594,7 +633,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
return
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
del layer.g1_alphas
del layer.g2_alphas
del layer.w13_input_scale_quant
del layer.w2_input_scale_quant
del layer.w13_blockscale_swizzled
del layer.w2_blockscale_swizzled
def apply(
self,
@@ -614,6 +661,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
if self.use_marlin:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
assert activation == "silu", "Only SiLU activation is supported."
assert not apply_router_weight_on_input, (
"Router weight on input is not "

View File

@@ -33,7 +33,7 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
has_zp: bool,
has_zp: Optional[bool] = None,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
@@ -45,6 +45,16 @@ def query_marlin_supported_quant_types(
if device_capability < 80:
return []
# - has_zp is True: return quant_types that has zero points
# - has_zp is False: return quant_types that has not zero points
# - has_zp is None: both
if has_zp is None:
types0 = query_marlin_supported_quant_types(False, include_fp_type,
device_capability)
types1 = query_marlin_supported_quant_types(True, include_fp_type,
device_capability)
return types0 + types1
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4]
@@ -52,7 +62,7 @@ def query_marlin_supported_quant_types(
# GPTQ style, unsigned + symmetric bias
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn]
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
return res
@@ -394,6 +404,7 @@ def apply_gptq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
@@ -439,6 +450,7 @@ def apply_awq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,

View File

@@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
logger = init_logger(__name__)
def is_fp4_marlin_supported():
    """Return True if this platform can run the FP4 Marlin kernel.

    Marlin requires compute capability 8.0 (Ampere) or newer.
    """
    min_capability = 80
    return current_platform.has_device_capability(min_capability)
def fp4_marlin_process_scales(marlin_scales):
assert (marlin_scales >= 0).all()
# convert to half first, we would convert to fp8 later
marlin_scales = marlin_scales.to(torch.half)
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1)
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1)
# We assume that weight_scale (FP8-S1E4M3) is always greater
# than or equal to 0. So we can convert
# (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
# After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
# when weight_scale > 0. This allows us to have an exponent bias
# closer to zero after dequantization.
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
marlin_scales = marlin_scales[:, 1::2].contiguous()
return marlin_scales
def fp4_marlin_process_global_scale(global_scale):
    """Pre-scale the FP4 global scale for the target activation dtype.

    The factor is 2**(exponent_bias - 7) where
    exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1):
      fp16 -> 2**4 - 2**1 = 14, bf16 -> 2**7 - 2**1 = 126.
    """
    assert global_scale.dtype in [torch.half, torch.bfloat16]
    # Exponent widths: fp4 (e2m1) has 2 exponent bits; fp16 has 5, bf16 has 8.
    fp4_exponent = 2
    target_exponent = 5 if global_scale.dtype == torch.half else 8
    exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
    return global_scale * (2.0**(exponent_bias - 7))
def apply_fp4_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_scale_2: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    """Run a weight-only FP4 GEMM through the Marlin kernel.

    For GPUs that lack FP4 hardware support, the Marlin kernel provides
    fast weight-only FP4 execution. Returns a tensor shaped like *input*
    with its last dimension replaced by *size_n*; *bias* is added in-place
    when given.
    """
    x_2d = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    use_atomic_add = should_use_atomic_add_reduce(m=x_2d.size(0),
                                                  n=size_n,
                                                  k=size_k,
                                                  device=input.device,
                                                  dtype=input.dtype)

    out = ops.gptq_marlin_gemm(a=x_2d,
                               c=None,
                               b_q_weight=weight,
                               b_scales=weight_scale,
                               global_scale=weight_scale_2,
                               b_zeros=None,
                               g_idx=None,
                               perm=None,
                               workspace=workspace,
                               b_q_type=scalar_types.float4_e2m1f,
                               size_m=x_2d.size(0),
                               size_n=size_n,
                               size_k=size_k,
                               use_atomic_add=use_atomic_add,
                               use_fp32_reduce=use_fp32_reduce)

    if bias is not None:
        out.add_(bias)  # in-place to avoid an extra allocation

    return out.reshape(out_shape)
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert a linear layer's FP4 weights and scales to Marlin format.

    Replaces ``layer.weight``, ``layer.weight_scale`` and
    ``layer.weight_scale_2`` in-place and attaches a Marlin workspace.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    dtype = layer.params_dtype

    # Two FP4 values are packed into each uint8 along K.
    assert layer.weight.shape == (size_n, size_k // 2)

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device)

    # WEIGHT: repack to the Marlin tile layout (no activation reordering,
    # hence the empty permutation).
    no_perm = torch.empty(0, dtype=torch.int, device=device)
    packed = layer.weight.view(torch.int32).T.contiguous()
    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=packed,
                                            perm=no_perm,
                                            size_k=size_k,
                                            size_n=size_n,
                                            num_bits=4)
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES: permute to Marlin order, then re-encode for fp8 dequant.
    scales = marlin_permute_scales(s=layer.weight_scale.T.to(dtype),
                                   size_k=size_k,
                                   size_n=size_n,
                                   group_size=16)
    scales = fp4_marlin_process_scales(scales)
    layer.weight_scale = torch.nn.Parameter(scales, requires_grad=False)

    global_scale = fp4_marlin_process_global_scale(
        layer.weight_scale_2.to(dtype))
    layer.weight_scale_2 = torch.nn.Parameter(global_scale,
                                              requires_grad=False)
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert a fused-MoE layer's FP4 weights and scales to Marlin format.

    Repacks the per-expert w13/w2 weights, re-encodes their group and
    global scales in-place, and attaches a Marlin workspace tensor.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    num_experts = layer.num_experts
    hidden = layer.hidden_size
    intermediate = layer.intermediate_size_per_partition

    device = layer.w13_weight.device
    dtype = layer.params_dtype

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device, 4)
    no_perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT: repack every expert's weight to the Marlin tile layout.
    for name in ["w13_weight", "w2_weight"]:
        weight = getattr(layer, name)
        if "w13" in name:
            size_n, size_k = intermediate * 2, hidden
        else:
            size_n, size_k = hidden, intermediate
        # Two FP4 values are packed into each uint8 along K.
        assert weight.shape == (num_experts, size_n, size_k // 2)

        repacked = [
            ops.gptq_marlin_repack(
                b_q_weight=weight[i].view(torch.int32).T.contiguous(),
                perm=no_perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4) for i in range(num_experts)
        ]
        stacked = torch.nn.Parameter(torch.stack(repacked, 0),
                                     requires_grad=False)
        setattr(layer, name, stacked)

    # WEIGHT SCALES: permute per expert, then re-encode for fp8 dequant.
    for name in ["w13", "w2"]:
        scales = getattr(layer, name + "_weight_scale").to(dtype)
        global_scale = getattr(layer, name + "_weight_scale_2").to(dtype)
        if "w13" in name:
            size_n, size_k = intermediate * 2, hidden
        else:
            size_n, size_k = hidden, intermediate

        processed = [
            fp4_marlin_process_scales(
                marlin_permute_scales(s=scales[i].T,
                                      size_k=size_k,
                                      size_n=size_n,
                                      group_size=16))
            for i in range(num_experts)
        ]
        setattr(
            layer, name + "_weight_scale",
            torch.nn.Parameter(torch.stack(processed, 0),
                               requires_grad=False))

        global_scale = fp4_marlin_process_global_scale(global_scale)
        setattr(layer, name + "_weight_scale_2",
                torch.nn.Parameter(global_scale, requires_grad=False))
def rand_marlin_weight_fp4_like(weight, group_size):
    """Build a random FP4 Marlin weight shaped like *weight* (test helper).

    Returns ``(weight_ref.T, marlin_qweight, marlin_scales, global_scale)``
    where ``weight_ref`` is the dequantized reference for the packed weight.
    """
    assert group_size > 0
    size_n, size_k = weight.shape
    device = weight.device

    # Per-group fp8 scales (fp4 e2m1 max magnitude is 6) plus a single
    # global scale (fp8 e4m3 max finite value is 448).
    scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
    global_scale = scales.max() / 448
    scales = (scales / global_scale).to(torch.float8_e4m3fn)

    # Random packed fp4 pairs: two nibbles per byte.
    packed = torch.randint(0,
                           256, (size_n, size_k // 2),
                           dtype=torch.uint8,
                           device=weight.device)

    def _dequant_high_nibble(byte_tensor):
        # Move the fp4 (e2m1) sign/exponent/mantissa bits of the high
        # nibble into fp8 (e4m3) bit positions; the exponent re-bias
        # corresponds to a factor of 2**6.
        fp8_bits = ((byte_tensor & 0b10000000) |
                    ((byte_tensor & 0b01110000) >> 2))
        return fp8_bits.view(torch.float8_e4m3fn).to(weight.dtype) * (2**6)

    high_vals = _dequant_high_nibble(packed)
    low_vals = _dequant_high_nibble(packed << 4)

    # The low nibble precedes the high nibble along K.
    weight_ref = torch.cat([low_vals.unsqueeze(2),
                            high_vals.unsqueeze(2)], 2).view(size_n, size_k)
    weight_ref = weight_ref * global_scale.to(weight.dtype) * \
        scales.repeat_interleave(group_size, 1).to(weight.dtype)

    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed.view(torch.int32).T.contiguous(),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=4,
    )

    marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=group_size)
    marlin_scales = fp4_marlin_process_scales(marlin_scales)

    global_scale = fp4_marlin_process_global_scale(global_scale)

    return weight_ref.T, marlin_qweight, marlin_scales, global_scale

View File

@@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
return current_platform.has_device_capability(80)
def fp8_fused_exponent_bias_into_scales(scales):
    """Fold the fp8 -> target-dtype exponent bias into dequant scales.

    The multiplier is 2**exponent_bias where
    exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1):
      fp16 -> 2**4 - 2**3 = 8, bf16 -> 2**7 - 2**3 = 120.
    """
    # Exponent widths: fp8 (e4m3) has 4 exponent bits; fp16 has 5, bf16 has 8.
    fp8_exponent = 4
    if scales.dtype == torch.half:
        target_exponent = 5
    elif scales.dtype == torch.bfloat16:
        target_exponent = 8
    exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
    # 2**exponent_bias is exactly representable in both fp16 (2**8) and
    # bf16 (2**120), so a single scalar multiply suffices.
    return scales * (2.0**exponent_bias)
def apply_fp8_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
c=None,
b_q_weight=weight,
b_scales=weight_scale,
global_scale=None,
b_zeros=None,
g_idx=None,
perm=None,
@@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.T.contiguous()
block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1)
scales = scales.repeat_interleave(block_n, 1)
# size_n may not divisible by block_size[0]
scales = scales[:, :part_size_n]
@@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
@@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1)
block_n = layer.weight_block_size[0]
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
scales = scales.repeat_interleave(block_n, 2)
# size_n may not divisible by block_size[0]
scales = scales[..., :size_n].contiguous()
@@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
size_n=size_n,
group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
return weight_ref.T, marlin_qweight, marlin_scales