[Kernel] some optimizations for dense marlin and moe marlin (#16850)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
Jinzhen Lin
2025-05-06 00:39:30 +08:00
committed by GitHub
parent f62cad6431
commit 1d0c9d6b2d
26 changed files with 3512 additions and 3268 deletions

View File

@@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
marlin_make_empty_g_idx, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
@@ -267,8 +268,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
layer.workspace = marlin_make_workspace_new(device)
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
@@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -396,11 +399,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
@@ -511,10 +510,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
)
workspace=layer.workspace)

View File

@@ -55,7 +55,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
prepare_fp8_layer_for_marlin(layer, strategy="channel")
prepare_fp8_layer_for_marlin(layer)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
@@ -68,6 +68,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(

View File

@@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
cutlass_fp8_supported, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = False
self.block_quant = self.quant_config.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.fp8_linear = Fp8LinearOp(
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic=cutlass_fp8_supported())
@@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
if self.block_quant:
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
@@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
@@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
return weight
def process_weights_after_loading(self, layer: Module) -> None:
size_k_first = True
# TODO(rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
size_k_first = False
if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz(
@@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False)
return
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
elif not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)),
layer.logical_widths)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
@@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
weight = layer.weight
weight_scale = layer.weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
if not self.use_marlin:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
@@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
requires_grad=False)
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer)
prepare_fp8_layer_for_marlin(layer, size_k_first)
# Activations not quantized for marlin.
del layer.input_scale
@@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
@@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.quant_config.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.quant_config.weight_block_size[0],
@@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
elif not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
@@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
@@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
return
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
def apply(
self,
@@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
)
if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts(
x,
layer.w13_weight,

View File

@@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_moe_marlin_supports_layer,
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
marlin_make_workspace_new, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
@@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
else:
raise ValueError(
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
def create_weights(
self,
@@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
workspace=layer.workspace,
is_k_full=self.is_k_full)

View File

@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
@@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
device)
self.workspace = marlin_make_workspace_new(device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
@@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
has_zp=self.config.zero_points,
is_k_full=self.is_k_full,
bias=bias)

View File

@@ -7,12 +7,15 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
def query_marlin_supported_quant_types(
has_zp: bool,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
return [scalar_types.uint4]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn]
return res
def _check_marlin_supported(
@@ -62,7 +68,7 @@ def _check_marlin_supported(
capability_tuple.to_int())
supported_types = query_marlin_supported_quant_types(
has_zp, device_capability)
has_zp, True, device_capability)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad=False)
def marlin_make_workspace_new(device: torch.device,
                              max_blocks_per_sm: int = 1) -> torch.Tensor:
    """Allocate a zeroed int32 workspace tensor for the new marlin kernel.

    The new kernel sizes its workspace by the number of threadblocks,
    which is SM count * max_blocks_per_sm.
    """
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    workspace_size = sm_count * max_blocks_per_sm
    return torch.zeros(workspace_size,
                       dtype=torch.int,
                       requires_grad=False,
                       device=device)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return output
def maybe_warn_marlin_atomic_add(device, dtype):
    """Log a one-time hint when bf16 marlin runs on a pre-SM90 GPU."""
    # Do not emit logs while dynamo is tracing.
    if torch.compiler.is_dynamo_compiling():
        return
    major, _ = torch.cuda.get_device_capability(device)
    if dtype == torch.bfloat16 and major < 9:
        logger.info_once(
            "You are running Marlin kernel with bf16 on GPUs before SM90. "
            "You can consider change to fp16 to achieve better performance "
            "if possible.")
def maybe_warn_marlin_atomic_add_env():
    """Log a one-time hint that VLLM_MARLIN_USE_ATOMIC_ADD may help."""
    # Skip while dynamo is tracing, and say nothing if the feature is
    # already switched on.
    if torch.compiler.is_dynamo_compiling() or envs.VLLM_MARLIN_USE_ATOMIC_ADD:
        return
    logger.info_once(
        "Marlin kernel can achieve better performance for small size_n "
        "with experimental use_atomic_add feature. "
        "You can consider set environment variable "
        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
dtype: torch.dtype) -> bool:
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
if n >= 2048 or k < 2048 or device.type != "cuda":
return False
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
maybe_warn_marlin_atomic_add_env()
return False
# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
maybe_warn_marlin_atomic_add(device, dtype)
return False
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
return n < 2048 and k >= 2048
return True
def apply_gptq_marlin_linear(
@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
has_zp: bool,
is_k_full: bool,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
weight_scale,
weight_zp,
@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
@@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
weight_scale,
weight_zp,
@@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

View File

@@ -6,9 +6,11 @@ import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -18,30 +20,40 @@ def is_fp8_marlin_supported():
def apply_fp8_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor],
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n, )
output = ops.fp8_marlin_gemm(
a=reshaped_x,
b_q_weight=weight,
b_scales=weight_scale,
workspace=workspace,
num_bits=8,
size_m=reshaped_x.shape[0],
size_n=size_n,
size_k=size_k,
)
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=size_n,
k=size_k,
device=input.device,
dtype=input.dtype)
output = ops.gptq_marlin_gemm(a=reshaped_x,
c=None,
b_q_weight=weight,
b_scales=weight_scale,
b_zeros=None,
g_idx=None,
perm=None,
workspace=workspace,
b_q_type=scalar_types.float8_e4m3fn,
size_m=reshaped_x.size(0),
size_n=size_n,
size_k=size_k,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
@@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
strategy: str = "tensor") -> None:
size_k_first: bool = True) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
@@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
if size_k_first:
assert layer.weight.shape == (part_size_k, part_size_n)
else:
assert layer.weight.shape == (part_size_n, part_size_k)
device = layer.weight.device
# WORKSPACE
layer.workspace = marlin_make_workspace(part_size_n, device)
layer.workspace = marlin_make_workspace_new(device)
# WEIGHT
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
layer.weight),
perm=torch.empty(0,
dtype=torch.int,
device=device),
perm = torch.empty(0, dtype=torch.int, device=device)
qweight = pack_fp8_to_int32(layer.weight, size_k_first)
if not size_k_first:
qweight = qweight.T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
perm=perm,
size_k=part_size_k,
size_n=part_size_n,
num_bits=8)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
scales = layer.weight_scale.to(layer.orig_dtype)
# Permute scales
if "weight_scale" in dir(layer):
scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
if layer.weight_block_size is None:
group_size = -1
else:
group_size = layer.weight_block_size[1]
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if layer.weight_block_size is None:
if scales.nelement() == 1:
# tensor-wise quantization -> channel-wise quantization
# (1, 1) =>(repeat)=> (1, size_n)
scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
elif scales.nelement() > 1 and scales.nelement() != part_size_n:
assert part_size_n % scales.nelement() == 0
s_size = scales.nelement()
# tensor-wise quantization (for gate-up proj)
# -> channel-wise quantization
# (1, s_size) =>(repeat)=> (1, size_n)
scales = scales.view(1, s_size)
scales = scales.repeat_interleave(part_size_n // s_size, 1)
else:
# channel-wise quantization
# (1, size_n)
scales = scales.view(1, part_size_n)
else:
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1)
# size_n may not be divisible by block_size[0]
scales = scales[:, :part_size_n]
marlin_scales = marlin_permute_scales(s=scales,
size_k=part_size_k,
size_n=part_size_n,
group_size=-1)
group_size=group_size)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
                                     size_k_first: bool = True) -> None:
    """Convert a fused-MoE FP8 layer in place to the marlin weight format.

    Repacks ``w13_weight``/``w2_weight`` per expert into marlin layout,
    converts the weight scales (tensor-/channel-/block-wise) into the
    channel- or group-wise layout marlin supports, and allocates the
    marlin workspace on ``layer``.

    Args:
        layer: fused MoE layer exposing ``num_experts``, ``hidden_size``,
            ``intermediate_size_per_partition``, ``orig_dtype``,
            ``weight_block_size`` and the w13/w2 weight (+scale) tensors.
        size_k_first: whether each expert weight is stored (size_k, size_n);
            if False it is stored (size_n, size_k).
    """
    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition

    # WORKSPACE
    # MoE marlin uses up to 4 blocks per SM (see marlin_make_workspace_new).
    device = layer.w13_weight.device
    layer.workspace = marlin_make_workspace_new(device, 4)
    # Empty permutation => no act-order reordering during repack.
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format
    for name in ["w13_weight", "w2_weight"]:
        weight = getattr(layer, name)
        tensor_list = []
        # w13 is the fused gate/up projection: output dim is 2*n.
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n
        if size_k_first:
            assert weight.shape == (e, size_k, size_n)
        else:
            assert weight.shape == (e, size_n, size_k)

        # Repack expert by expert; gptq_marlin_repack expects a 2-D
        # (size_k // pack_factor, size_n) int32 input.
        for i in range(e):
            qweight = pack_fp8_to_int32(weight[i], size_k_first)
            if not size_k_first:
                qweight = qweight.T.contiguous()

            marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
                                                    perm=perm,
                                                    size_k=size_k,
                                                    size_n=size_n,
                                                    num_bits=8)
            tensor_list.append(marlin_qweight)

        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        weight = torch.nn.Parameter(weight, requires_grad=False)
        setattr(layer, name, weight)

    # WEIGHT SCALES
    # Permute scales
    if layer.weight_block_size is None:
        group_size = -1
    else:
        # block-wise checkpoints quantize along k in blocks of
        # weight_block_size[1]; marlin treats that as the group size.
        group_size = layer.weight_block_size[1]

    for name in ["w13", "w2"]:
        # Scales may be stored as "<name>_weight_scale" or
        # "<name>_weight_scale_inv" depending on the checkpoint format;
        # either way the attribute is consumed and replaced below.
        if name + "_weight_scale" in dir(layer):
            new_name = name + "_weight_scale"
            scales = getattr(layer, new_name).to(layer.orig_dtype)
            delattr(layer, new_name)
        elif name + "_weight_scale_inv" in dir(layer):
            new_name = name + "_weight_scale_inv"
            scales = getattr(layer, new_name).to(layer.orig_dtype)
            delattr(layer, new_name)

        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        # marlin kernel only support channel-wise and group-wise quantization
        # we need to convert the scales
        if layer.weight_block_size is None:
            if scales.nelement() == e:
                # tensor-wise quantization -> channel-wise quantization
                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
            elif scales.nelement() > e and scales.nelement() != e * size_n:
                assert (e * size_n) % scales.nelement() == 0
                s_size = scales.nelement() // e
                # tensor-wise quantization (for gate-up proj)
                # -> channel-wise quantization
                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, s_size)
                scales = scales.repeat_interleave(size_n // s_size, 2)
            else:
                # channel-wise quantization
                # (e, 1, size_n)
                scales = scales.view(e, 1, size_n)
        else:
            # block-wise quantization -> group-wise quantization
            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
            # =>(repeat)=> (e, size_k // block_size[1], size_n)
            block_n = layer.weight_block_size[0]
            scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
            # size_n may not be divisible by block_size[0]
            scales = scales[..., :size_n].contiguous()

        for i in range(e):
            marlin_scales = marlin_permute_scales(s=scales[i],
                                                  size_k=size_k,
                                                  size_n=size_n,
                                                  group_size=group_size)
            tensor_list.append(marlin_scales)

        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        scales = torch.nn.Parameter(scales, requires_grad=False)
        setattr(layer, name + "_weight_scale", scales)
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
size_k_first: bool = True) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.shape[0] % 4 == 0
assert fp8_tensor.ndim == 2
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
fp8_tensor = fp8_tensor.contiguous()
# fp8_tensor is contiguous and have shape (N, K) now
# with `.view(torch.int32)`, it become (N, K // 4)
int32_tensor = fp8_tensor.view(torch.int32)
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = (byte_tensor[:, 0].to(torch.int32) |
(byte_tensor[:, 1].to(torch.int32) << 8) |
(byte_tensor[:, 2].to(torch.int32) << 16) |
(byte_tensor[:, 3].to(torch.int32) << 24))
def marlin_quant_fp8_torch(weight, group_size):
    """Quantize a dense weight to fp8 and return marlin-format tensors.

    Pure-torch reference quantizer (plus the repack op) used for tests.

    Args:
        weight: dense weight of shape (size_n, size_k).
        group_size: quantization group size along k, or -1 for
            channel-wise quantization.

    Returns:
        Tuple ``(weight_ref, marlin_qweight, marlin_scales)`` where
        ``weight_ref`` is the dequantized reference weight transposed to
        (size_k, size_n).
    """
    # NOTE: a dangling `return packed.view(...)` remnant from the old
    # pack_fp8_to_int32 implementation was spliced into this function's
    # body by the diff rendering; it has been removed.
    size_n, size_k = weight.shape
    device = weight.device

    # 448 is the max representable value of float8_e4m3fn.
    if group_size != -1:
        # group-wise: one scale per group of `group_size` k-elements
        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
        repeated_scales = scales.repeat_interleave(group_size, 1)
        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
    else:
        # channel-wise: one scale per output channel
        # (view with -1 infers the full k dimension)
        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
        repeated_scales = scales.repeat_interleave(size_k, 1)
        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales

    # Pack to gptq int32 layout, then repack into marlin tile order.
    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_weight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    marlin_scales = marlin_permute_scales(s=scales.T,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=group_size)

    return weight_ref.T, marlin_qweight, marlin_scales