[Kernel] fp4 marlin kernel (#17687)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -304,8 +304,10 @@ class HQQMarlinMethod(LinearMethodBase):
|
||||
|
||||
marlin_out = ops.gptq_marlin_gemm(
|
||||
x,
|
||||
None,
|
||||
layer.marlin_qweight,
|
||||
scales,
|
||||
None,
|
||||
zeros,
|
||||
layer.g_idx,
|
||||
layer.g_idx_sort_indices,
|
||||
@@ -315,7 +317,7 @@ class HQQMarlinMethod(LinearMethodBase):
|
||||
self.output_size_per_partition,
|
||||
self.input_size_per_partition,
|
||||
True, # is_k_full
|
||||
True, # has_zp
|
||||
False, # use atomic add
|
||||
True, # use 32-bit reduce
|
||||
True, # use float zp
|
||||
)
|
||||
|
||||
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
@@ -24,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -196,7 +200,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@@ -278,9 +282,15 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
||||
self.use_marlin = False
|
||||
|
||||
if not self.cutlass_nvfp4_supported:
|
||||
raise ValueError("Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and above.")
|
||||
if is_fp4_marlin_supported():
|
||||
self.use_marlin = True
|
||||
else:
|
||||
raise ValueError("Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above.")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -392,12 +402,29 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
del layer.weight_scale_swizzled
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.use_marlin:
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
output_dtype = x.dtype
|
||||
|
||||
# for input only the contracting dimension has a constraint.
|
||||
@@ -434,6 +461,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
||||
self.use_marlin = False
|
||||
|
||||
if not self.cutlass_nvfp4_supported:
|
||||
if is_fp4_marlin_supported():
|
||||
self.use_marlin = True
|
||||
else:
|
||||
raise ValueError("Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above.")
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@@ -442,6 +479,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
raise ValueError("NVFP4 quantization was selected, "
|
||||
" dynamic quantization is not supported.")
|
||||
|
||||
layer.num_experts = num_experts
|
||||
layer.params_dtype = params_dtype
|
||||
layer.quant_config = self.quant_config
|
||||
weight_dtype = torch.uint8
|
||||
weight_scale_dtype = torch.float8_e4m3fn
|
||||
@@ -594,7 +633,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
del layer.g1_alphas
|
||||
del layer.g2_alphas
|
||||
del layer.w13_input_scale_quant
|
||||
del layer.w2_input_scale_quant
|
||||
del layer.w13_blockscale_swizzled
|
||||
del layer.w2_blockscale_swizzled
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -614,6 +661,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
):
|
||||
if self.use_marlin:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_scale1=layer.w13_weight_scale_2,
|
||||
global_scale2=layer.w2_weight_scale_2,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Router weight on input is not "
|
||||
|
||||
@@ -33,7 +33,7 @@ USE_FP32_REDUCE_DEFAULT = True
|
||||
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
||||
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
||||
def query_marlin_supported_quant_types(
|
||||
has_zp: bool,
|
||||
has_zp: Optional[bool] = None,
|
||||
include_fp_type: bool = True,
|
||||
device_capability: Optional[int] = None,
|
||||
):
|
||||
@@ -45,6 +45,16 @@ def query_marlin_supported_quant_types(
|
||||
if device_capability < 80:
|
||||
return []
|
||||
|
||||
# - has_zp is True: return quant_types that has zero points
|
||||
# - has_zp is False: return quant_types that has not zero points
|
||||
# - has_zp is None: both
|
||||
if has_zp is None:
|
||||
types0 = query_marlin_supported_quant_types(False, include_fp_type,
|
||||
device_capability)
|
||||
types1 = query_marlin_supported_quant_types(True, include_fp_type,
|
||||
device_capability)
|
||||
return types0 + types1
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + runtime zero-point
|
||||
return [scalar_types.uint4]
|
||||
@@ -52,7 +62,7 @@ def query_marlin_supported_quant_types(
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
res = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
if include_fp_type:
|
||||
res += [scalar_types.float8_e4m3fn]
|
||||
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
|
||||
return res
|
||||
|
||||
|
||||
@@ -394,6 +404,7 @@ def apply_gptq_marlin_linear(
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
@@ -439,6 +450,7 @@ def apply_awq_marlin_linear(
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
|
||||
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
||||
should_use_atomic_add_reduce)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_fp4_marlin_supported():
    """Report whether the FP4 Marlin fallback can be used on this platform.

    Returns whatever ``current_platform.has_device_capability`` reports
    for capability ``80``.
    """
    required_capability = 80
    return current_platform.has_device_capability(required_capability)
|
||||
|
||||
|
||||
def fp4_marlin_process_scales(marlin_scales):
|
||||
assert (marlin_scales >= 0).all()
|
||||
|
||||
# convert to half first, we would convert to fp8 later
|
||||
marlin_scales = marlin_scales.to(torch.half)
|
||||
|
||||
# 8 is the number of scale number using by one thread
|
||||
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
|
||||
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
|
||||
marlin_scales.size(0) * 2, -1)
|
||||
|
||||
# fit the layout of fp8 dequantization
|
||||
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
|
||||
marlin_scales.size(0), -1)
|
||||
|
||||
# We assume that weight_scale (FP8-S1E4M3) is always greater
|
||||
# than or equal to 0. So we can convert
|
||||
# (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
|
||||
# After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
|
||||
# when weight_scale > 0. This allows us to have an exponent bias
|
||||
# closer to zero after dequantization.
|
||||
|
||||
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
|
||||
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
|
||||
marlin_scales = marlin_scales[:, 1::2].contiguous()
|
||||
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def fp4_marlin_process_global_scale(global_scale):
    """Fold the FP4-to-activation exponent bias into the global scale.

    Args:
        global_scale: tensor of dtype ``torch.half`` or ``torch.bfloat16``
            (asserted).

    Returns:
        ``global_scale`` multiplied by ``2 ** (exponent_bias - 7)``, where
        ``exponent_bias`` depends on the dtype as computed below.
    """
    scale_dtype = global_scale.dtype
    assert scale_dtype in [torch.half, torch.bfloat16]

    # Exponent widths of the source (FP4) and target (fp16 / bf16) formats.
    fp4_exponent_bits = 2
    if scale_dtype == torch.half:
        target_exponent_bits = 5
    elif scale_dtype == torch.bfloat16:
        target_exponent_bits = 8

    # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
    # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
    target_bias = 2**(target_exponent_bits - 1)
    fp4_bias = 2**(fp4_exponent_bits - 1)
    exponent_bias = target_bias - fp4_bias

    # NOTE(review): the extra "- 7" presumably offsets the 2 ** 7 factor
    # applied to the block scales in fp4_marlin_process_scales -- confirm.
    correction = 2.0**(exponent_bias - 7)
    return global_scale * correction
|
||||
|
||||
|
||||
def apply_fp4_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_scale_2: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    """Run a linear layer on FP4 weights through the Marlin GEMM op.

    For GPUs that lack FP4 hardware support, the Marlin kernel can be
    leveraged for fast weight-only FP4 quantization.

    Args:
        input: activations; all leading dimensions are flattened into one.
        weight: quantized weight passed to the kernel as ``b_q_weight``.
        weight_scale: passed to the kernel as ``b_scales``.
        weight_scale_2: passed to the kernel as ``global_scale``.
        workspace: Marlin workspace buffer.
        size_n: output feature count (last dimension of the result).
        size_k: reduction dimension handed to the kernel.
        bias: optional bias added in place to the kernel output.
        use_fp32_reduce: forwarded to the kernel.

    Returns:
        Tensor shaped ``input.shape[:-1] + (size_n,)``.
    """
    flat_input = input.reshape(-1, input.shape[-1])
    num_rows = flat_input.size(0)
    result_shape = input.shape[:-1] + (size_n, )

    atomic_add = should_use_atomic_add_reduce(m=num_rows,
                                              n=size_n,
                                              k=size_k,
                                              device=input.device,
                                              dtype=input.dtype)

    # FP4 has no zero points and no activation reordering, so b_zeros,
    # g_idx and perm are all left unset.
    result = ops.gptq_marlin_gemm(a=flat_input,
                                  c=None,
                                  b_q_weight=weight,
                                  b_scales=weight_scale,
                                  global_scale=weight_scale_2,
                                  b_zeros=None,
                                  g_idx=None,
                                  perm=None,
                                  workspace=workspace,
                                  b_q_type=scalar_types.float4_e2m1f,
                                  size_m=num_rows,
                                  size_n=size_n,
                                  size_k=size_k,
                                  use_atomic_add=atomic_add,
                                  use_fp32_reduce=use_fp32_reduce)

    if bias is not None:
        # In-place add on the kernel output.
        result.add_(bias)

    return result.reshape(result_shape)
|
||||
|
||||
|
||||
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert a loaded FP4 linear layer in place to the Marlin format.

    Reads ``output_size_per_partition``, ``input_size_per_partition``,
    ``params_dtype``, ``weight``, ``weight_scale`` and ``weight_scale_2``
    from ``layer``. It then sets ``layer.workspace`` and replaces the three
    weight/scale attributes with non-trainable parameters in Marlin layout.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    act_dtype = layer.params_dtype

    # The stored weight is expected to have half as many columns as size_k.
    assert layer.weight.shape == (size_n, size_k // 2)

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device)

    # WEIGHT
    # Repack weights to marlin format (empty perm: no reordering).
    empty_perm = torch.empty(0, dtype=torch.int, device=device)
    weight_int32 = layer.weight.view(torch.int32).T.contiguous()
    repacked_weight = ops.gptq_marlin_repack(b_q_weight=weight_int32,
                                             perm=empty_perm,
                                             size_k=size_k,
                                             size_n=size_n,
                                             num_bits=4)
    layer.weight = torch.nn.Parameter(repacked_weight, requires_grad=False)

    # WEIGHT SCALES
    # Permute the block scales, then re-encode them for the kernel.
    block_scales = layer.weight_scale.T.to(act_dtype)
    block_scales = marlin_permute_scales(s=block_scales,
                                         size_k=size_k,
                                         size_n=size_n,
                                         group_size=16)
    block_scales = fp4_marlin_process_scales(block_scales)
    layer.weight_scale = torch.nn.Parameter(block_scales,
                                            requires_grad=False)

    # Global scale: cast to the activation dtype and fold in the exponent
    # bias.
    global_scale = fp4_marlin_process_global_scale(
        layer.weight_scale_2.to(act_dtype))
    layer.weight_scale_2 = torch.nn.Parameter(global_scale,
                                              requires_grad=False)
|
||||
|
||||
|
||||
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert a loaded FP4 fused-MoE layer in place to the Marlin format.

    Reads ``num_experts``, ``hidden_size``,
    ``intermediate_size_per_partition`` and ``params_dtype`` from
    ``layer``. It sets ``layer.workspace`` and replaces the ``w13_*`` /
    ``w2_*`` weight, weight-scale and global-scale attributes with
    non-trainable parameters, processing each expert separately.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    num_experts = layer.num_experts
    hidden = layer.hidden_size
    intermediate = layer.intermediate_size_per_partition

    # WORKSPACE
    device = layer.w13_weight.device
    act_dtype = layer.params_dtype
    layer.workspace = marlin_make_workspace_new(device, 4)
    empty_perm = torch.empty(0, dtype=torch.int, device=device)

    # (size_n, size_k) of each projection, keyed by attribute prefix.
    gemm_shapes = {
        "w13": (intermediate * 2, hidden),
        "w2": (hidden, intermediate),
    }

    # WEIGHT
    # Repack weights to marlin format, one expert at a time.
    for prefix, (size_n, size_k) in gemm_shapes.items():
        attr = prefix + "_weight"
        packed = getattr(layer, attr)

        assert packed.shape == (num_experts, size_n, size_k // 2)

        repacked = [
            ops.gptq_marlin_repack(
                b_q_weight=packed[expert].view(torch.int32).T.contiguous(),
                perm=empty_perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4) for expert in range(num_experts)
        ]
        stacked = torch.nn.Parameter(torch.stack(repacked, 0),
                                     requires_grad=False)
        setattr(layer, attr, stacked)

    # WEIGHT SCALES
    # Permute and re-encode the block scales, then adjust the global scale.
    for prefix, (size_n, size_k) in gemm_shapes.items():
        scale_attr = prefix + "_weight_scale"
        global_attr = prefix + "_weight_scale_2"

        block_scales = getattr(layer, scale_attr).to(act_dtype)
        global_scale = getattr(layer, global_attr).to(act_dtype)

        processed = [
            fp4_marlin_process_scales(
                marlin_permute_scales(s=block_scales[expert].T,
                                      size_k=size_k,
                                      size_n=size_n,
                                      group_size=16))
            for expert in range(num_experts)
        ]
        setattr(
            layer, scale_attr,
            torch.nn.Parameter(torch.stack(processed, 0),
                               requires_grad=False))

        global_scale = fp4_marlin_process_global_scale(global_scale)
        setattr(layer, global_attr,
                torch.nn.Parameter(global_scale, requires_grad=False))
|
||||
|
||||
|
||||
def rand_marlin_weight_fp4_like(weight, group_size):
    """Build a random FP4 weight in Marlin format plus its dense reference.

    ``weight`` supplies the ``(size_n, size_k)`` shape, dtype and device,
    and its per-group absolute maxima set the scales. The 4-bit payload
    itself is random.

    Args:
        weight: 2-D tensor of shape ``(size_n, size_k)``.
        group_size: positive number of consecutive elements along the
            second dimension sharing one scale.

    Returns:
        Tuple ``(weight_ref.T, marlin_qweight, marlin_scales,
        global_scale)``: the transposed dense reference weight, the
        repacked quantized weight, the processed block scales and the
        processed global scale.
    """
    assert group_size > 0
    size_n, size_k = weight.shape
    device = weight.device
    ref_dtype = weight.dtype

    # Per-group scale (abs-max / 6), then a global scale chosen so the
    # largest block scale maps to 448 before casting to FP8.
    group_absmax = weight.view(size_n, -1, group_size).abs().max(-1)[0]
    scales = group_absmax / 6
    global_scale = scales.max() / 448
    scales = (scales / global_scale).to(torch.float8_e4m3fn)

    # Two random 4-bit values per byte.
    fp4_weight = torch.randint(0,
                               256, (size_n, size_k // 2),
                               dtype=torch.uint8,
                               device=weight.device)

    def _decode_high_nibble(packed):
        # Keep the sign bit, move the other three bits of the high nibble
        # down by two, reinterpret the byte as FP8 and rescale by 2 ** 6.
        fp8_bits = (packed & 0b10000000) | ((packed & 0b01110000) >> 2)
        return fp8_bits.view(torch.float8_e4m3fn).to(ref_dtype) * (2**6)

    high_values = _decode_high_nibble(fp4_weight)
    # Shift the low nibble into the high position and decode it likewise.
    low_values = _decode_high_nibble(fp4_weight << 4)

    # Interleave so the low-nibble value precedes the high-nibble value.
    weight_ref = torch.stack([low_values, high_values],
                             2).view(size_n, size_k)
    expanded_scales = scales.repeat_interleave(group_size, 1).to(ref_dtype)
    weight_ref = weight_ref * global_scale.to(ref_dtype) * expanded_scales

    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=4,
    )

    marlin_scales = fp4_marlin_process_scales(
        marlin_permute_scales(s=scales.T.to(ref_dtype),
                              size_k=size_k,
                              size_n=size_n,
                              group_size=group_size))

    global_scale = fp4_marlin_process_global_scale(global_scale)

    return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||
@@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
|
||||
return current_platform.has_device_capability(80)
|
||||
|
||||
|
||||
def fp8_fused_exponent_bias_into_scales(scales):
    """Fold the FP8-to-activation exponent bias into the weight scales.

    Args:
        scales: tensor of dtype ``torch.half`` or ``torch.bfloat16``.

    Returns:
        ``scales`` multiplied elementwise by ``2 ** exponent_bias``, where
        ``exponent_bias`` is 8 for fp16 and 120 for bf16.

    Raises:
        ValueError: if ``scales`` has any other dtype. Previously this
            surfaced as an ``UnboundLocalError`` on ``target_exponent``.
    """
    fp8_exponent = 4
    if scales.dtype == torch.half:
        target_exponent = 5
    elif scales.dtype == torch.bfloat16:
        target_exponent = 8
    else:
        raise ValueError(
            "fp8_fused_exponent_bias_into_scales only supports torch.half "
            f"and torch.bfloat16 scales, got {scales.dtype}.")

    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
    exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)

    # Build the power of two in the scales' own dtype before multiplying.
    s = torch.ones_like(scales) * 2
    s = s**exponent_bias
    return scales * s
|
||||
|
||||
|
||||
def apply_fp8_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
b_scales=weight_scale,
|
||||
global_scale=None,
|
||||
b_zeros=None,
|
||||
g_idx=None,
|
||||
perm=None,
|
||||
@@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||
# block-wise quantization -> group-wise quantization
|
||||
# (size_k // block_size[1], ceil(size_n / block_size[0]))
|
||||
# =>(repeat)=> (size_k // block_size[1], size_n)
|
||||
if not size_k_first:
|
||||
scales = scales.T.contiguous()
|
||||
block_n = layer.weight_block_size[0]
|
||||
scales = scales.T.repeat_interleave(block_n, 1)
|
||||
scales = scales.repeat_interleave(block_n, 1)
|
||||
# size_n may not divisible by block_size[0]
|
||||
scales = scales[:, :part_size_n]
|
||||
|
||||
@@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=group_size)
|
||||
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
|
||||
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
||||
|
||||
|
||||
@@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||
# block-wise quantization -> group-wise quantization
|
||||
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
|
||||
# =>(repeat)=> (e, size_k // block_size[1], size_n)
|
||||
if not size_k_first:
|
||||
scales = scales.permute(0, 2, 1)
|
||||
block_n = layer.weight_block_size[0]
|
||||
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
|
||||
scales = scales.repeat_interleave(block_n, 2)
|
||||
# size_n may not divisible by block_size[0]
|
||||
scales = scales[..., :size_n].contiguous()
|
||||
|
||||
@@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
|
||||
size_n=size_n,
|
||||
group_size=group_size)
|
||||
|
||||
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales
|
||||
|
||||
Reference in New Issue
Block a user