[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||
check_marlin_supports_layer, check_moe_marlin_supports_layer,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
|
||||
marlin_permute_scales, moe_awq_to_marlin_zero_points,
|
||||
verify_marlin_supported, verify_marlin_supports_shape)
|
||||
marlin_make_empty_g_idx, marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales, marlin_permute_scales,
|
||||
moe_awq_to_marlin_zero_points, verify_marlin_supported,
|
||||
verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
@@ -267,8 +268,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace(
|
||||
layer.output_size_per_partition, device)
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Repack weights from AWQ format to marlin format.
|
||||
marlin_qweight = ops.awq_marlin_repack(
|
||||
@@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig):
    """Store the quantization config and select the 4-bit quant type.

    Raises:
        ValueError: if ``quant_config.weight_bits`` is not 4.
    """
    self.quant_config = quant_config

    # Only 4-bit weights are accepted by this method.
    weight_bits = self.quant_config.weight_bits
    if weight_bits != 4:
        raise ValueError("AWQMoEMethod only supports 4bit now.")

    self.quant_type = scalar_types.uint4
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@@ -396,11 +399,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
layer.workspace = torch.zeros((sms * 4, ),
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_qweight.shape[0]
|
||||
@@ -511,10 +510,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
workspace=layer.workspace)
|
||||
|
||||
@@ -55,7 +55,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||
requires_grad=False)
|
||||
prepare_fp8_layer_for_marlin(layer, strategy="channel")
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: List[int],
|
||||
@@ -68,6 +68,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
|
||||
@@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, all_close_1d, convert_to_channelwise,
|
||||
cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
|
||||
per_tensor_dequantize, requantize_with_max_scale)
|
||||
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported, maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
||||
requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
@@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.use_marlin = False
|
||||
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
if self.block_quant:
|
||||
# Marlin doesn't support block-wise fp8
|
||||
self.use_marlin = False
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
# Default to using per_token quantization if cutlass is supported
|
||||
use_per_token_if_dynamic=cutlass_fp8_supported())
|
||||
@@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
if self.block_quant:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
layer.weight_block_size = self.quant_config.weight_block_size
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0],
|
||||
self.quant_config.weight_block_size[1],
|
||||
@@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
f"{output_partition_size} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||
@@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return weight
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
size_k_first = True
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
size_k_first = False
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale_inv, _ = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
@@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
layer.weight_scale_inv = Parameter(weight_scale_inv,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||
scale=None)
|
||||
|
||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
||||
# so extend the weight scales to be channelwise.
|
||||
if self.use_marlin:
|
||||
assert weight_scale.numel() == 1
|
||||
weight_scale = convert_to_channelwise(
|
||||
weight_scale.expand(len(layer.logical_widths)),
|
||||
layer.logical_widths)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
@@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||
requires_grad=False)
|
||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
||||
# so extend the weight scales to be channelwise.
|
||||
if self.use_marlin:
|
||||
weight = layer.weight
|
||||
weight_scale = convert_to_channelwise(layer.weight_scale,
|
||||
layer.logical_widths)
|
||||
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
else:
|
||||
if not self.use_marlin:
|
||||
# Dequant -> Quant with max scale so we can run per tensor.
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
@@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
prepare_fp8_layer_for_marlin(layer, size_k_first)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
|
||||
@@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
# Check for DeepGemm support.
|
||||
self.allow_deep_gemm = False
|
||||
if envs.VLLM_USE_DEEP_GEMM:
|
||||
@@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
if self.block_quant:
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
layer.weight_block_size = self.quant_config.weight_block_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0],
|
||||
@@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale_inv = \
|
||||
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
||||
|
||||
return
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data,
|
||||
dtype=fp8_dtype)
|
||||
@@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
@@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
|
||||
@@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer,
|
||||
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
|
||||
verify_marlin_supported)
|
||||
marlin_make_workspace_new, marlin_moe_permute_scales,
|
||||
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
@@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
    """Store the quantization config and map its bit width to a quant type.

    4-bit weights select ``scalar_types.uint4b8`` and 8-bit weights select
    ``scalar_types.uint8b128``.

    Raises:
        ValueError: if the configured bit width is neither 4 nor 8.
    """
    self.quant_config = quant_config

    size_bits = self.quant_config.quant_type.size_bits
    if size_bits not in (4, 8):
        raise ValueError(
            "GPTQMarlinMoEMethod only supports int4 and int8 now.")

    self.quant_type = (scalar_types.uint4b8
                       if size_bits == 4 else scalar_types.uint8b128)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
layer.workspace = torch.zeros((sms * 4, ),
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
@@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full)
|
||||
|
||||
@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
||||
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
|
||||
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
@@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
|
||||
# Allocate marlin workspace.
|
||||
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
|
||||
device)
|
||||
self.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Default names since marlin requires empty parameters for these,
|
||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||
@@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
has_zp=self.config.zero_points,
|
||||
is_k_full=self.is_k_full,
|
||||
bias=bias)
|
||||
|
||||
@@ -7,12 +7,15 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
from .quant_utils import pack_cols, unpack_cols
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
GPTQ_MARLIN_TILE = 16
|
||||
GPTQ_MARLIN_MIN_THREAD_N = 64
|
||||
GPTQ_MARLIN_MIN_THREAD_K = 128
|
||||
@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
|
||||
# For binary size and compile time, we don't support the same types for with and
|
||||
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
||||
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
||||
def query_marlin_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
def query_marlin_supported_quant_types(
|
||||
has_zp: bool,
|
||||
include_fp_type: bool = True,
|
||||
device_capability: Optional[int] = None,
|
||||
):
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + runtime zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
return [scalar_types.uint4]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
|
||||
# to add `scalar_types.float8_e4m3fn` here
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
res = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
if include_fp_type:
|
||||
res += [scalar_types.float8_e4m3fn]
|
||||
return res
|
||||
|
||||
|
||||
def _check_marlin_supported(
|
||||
@@ -62,7 +68,7 @@ def _check_marlin_supported(
|
||||
capability_tuple.to_int())
|
||||
|
||||
supported_types = query_marlin_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
has_zp, True, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Marlin does not support weight_bits = {quant_type}. "
|
||||
@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def marlin_make_workspace_new(device: torch.device,
                              max_blocks_per_sm: int = 1) -> torch.Tensor:
    """Allocate the zero-filled int workspace for the new marlin kernel.

    The new marlin kernel uses the number of threadblocks as the workspace
    size, i.e. ``sm_count * max_blocks_per_sm`` for ``device``.

    Args:
        device: Device whose multiprocessor count is queried and on which
            the workspace is allocated.
        max_blocks_per_sm: Multiplier applied to the multiprocessor count.

    Returns:
        A 1-D ``torch.int`` tensor of zeros that does not require grad.
    """
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    num_threadblocks = sm_count * max_blocks_per_sm
    return torch.zeros(num_threadblocks,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)
|
||||
|
||||
|
||||
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """Return False only when act_order and is_row_parallel are both set."""
    if not act_order:
        return True
    return not is_row_parallel
|
||||
|
||||
@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
|
||||
return output
|
||||
|
||||
|
||||
def maybe_warn_marlin_atomic_add(device, dtype):
    """Log a one-time hint when bf16 is used on a GPU older than SM90.

    Nothing is logged while dynamo is compiling, when the device's major
    compute capability is at least 9, or when ``dtype`` is not bfloat16.
    """
    if torch.compiler.is_dynamo_compiling():
        return

    major_capability = torch.cuda.get_device_capability(device)[0]
    if major_capability >= 9 or dtype != torch.bfloat16:
        return

    logger.info_once(
        "You are running Marlin kernel with bf16 on GPUs before SM90. "
        "You can consider change to fp16 to achieve better performance "
        "if possible.")
|
||||
|
||||
|
||||
def maybe_warn_marlin_atomic_add_env():
    """Log a one-time hint about the VLLM_MARLIN_USE_ATOMIC_ADD variable.

    Nothing is logged while dynamo is compiling or when the variable is
    already enabled.
    """
    # Short-circuit keeps the original order: the env flag is only read
    # when dynamo is not compiling.
    if (torch.compiler.is_dynamo_compiling()
            or envs.VLLM_MARLIN_USE_ATOMIC_ADD):
        return

    logger.info_once(
        "Marlin kernel can achieve better performance for small size_n "
        "with experimental use_atomic_add feature. "
        "You can consider set environment variable "
        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
|
||||
|
||||
|
||||
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
|
||||
dtype: torch.dtype) -> bool:
|
||||
|
||||
# the performance of atomicAdd is better than global reduce
|
||||
# only when m*n is small and k is large
|
||||
if n >= 2048 or k < 2048 or device.type != "cuda":
|
||||
return False
|
||||
|
||||
# disable atomicAdd reduce by default,
|
||||
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
|
||||
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
|
||||
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
|
||||
maybe_warn_marlin_atomic_add_env()
|
||||
return False
|
||||
|
||||
# sm8x doesn't support atomicAdd + bfloat16 natively
|
||||
device_capability = torch.cuda.get_device_capability(device)
|
||||
if device_capability[0] < 9 and dtype == torch.bfloat16:
|
||||
maybe_warn_marlin_atomic_add(device, dtype)
|
||||
return False
|
||||
|
||||
# the performance of atomicAdd is better than global reduce
|
||||
# only when m*n is small and k is large
|
||||
return n < 2048 and k >= 2048
|
||||
return True
|
||||
|
||||
|
||||
def apply_gptq_marlin_linear(
|
||||
@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
|
||||
wtype: ScalarType,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
has_zp: bool,
|
||||
is_k_full: bool,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
||||
@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
|
||||
dtype=input.dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
weight_zp,
|
||||
@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
has_zp=has_zp,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
@@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
|
||||
dtype=input.dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
weight_zp,
|
||||
@@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=output_size_per_partition,
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=True,
|
||||
has_zp=True,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
@@ -6,9 +6,11 @@ import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
||||
should_use_atomic_add_reduce)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -18,30 +20,40 @@ def is_fp8_marlin_supported():
|
||||
|
||||
|
||||
def apply_fp8_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
bias: Optional[torch.Tensor],
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the
|
||||
# Marlin kernel for fast weight-only FP8 quantization
|
||||
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (size_n, )
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=reshaped_x,
|
||||
b_q_weight=weight,
|
||||
b_scales=weight_scale,
|
||||
workspace=workspace,
|
||||
num_bits=8,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=size_n,
|
||||
size_k=size_k,
|
||||
)
|
||||
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
|
||||
n=size_n,
|
||||
k=size_k,
|
||||
device=input.device,
|
||||
dtype=input.dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(a=reshaped_x,
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
b_scales=weight_scale,
|
||||
b_zeros=None,
|
||||
g_idx=None,
|
||||
perm=None,
|
||||
workspace=workspace,
|
||||
b_q_type=scalar_types.float8_e4m3fn,
|
||||
size_m=reshaped_x.size(0),
|
||||
size_n=size_n,
|
||||
size_k=size_k,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
@@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(
|
||||
|
||||
|
||||
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||
strategy: str = "tensor") -> None:
|
||||
size_k_first: bool = True) -> None:
|
||||
logger.warning_once(
|
||||
"Your GPU does not have native support for FP8 computation but "
|
||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||
@@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
|
||||
if size_k_first:
|
||||
assert layer.weight.shape == (part_size_k, part_size_n)
|
||||
else:
|
||||
assert layer.weight.shape == (part_size_n, part_size_k)
|
||||
|
||||
device = layer.weight.device
|
||||
|
||||
# WORKSPACE
|
||||
layer.workspace = marlin_make_workspace(part_size_n, device)
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# WEIGHT
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
|
||||
layer.weight),
|
||||
perm=torch.empty(0,
|
||||
dtype=torch.int,
|
||||
device=device),
|
||||
perm = torch.empty(0, dtype=torch.int, device=device)
|
||||
qweight = pack_fp8_to_int32(layer.weight, size_k_first)
|
||||
if not size_k_first:
|
||||
qweight = qweight.T.contiguous()
|
||||
|
||||
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
|
||||
perm=perm,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
num_bits=8)
|
||||
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
||||
|
||||
# WEIGHT SCALES
|
||||
scales = layer.weight_scale.to(layer.orig_dtype)
|
||||
# Permute scales
|
||||
if "weight_scale" in dir(layer):
|
||||
scales = layer.weight_scale.to(layer.orig_dtype)
|
||||
elif "weight_scale_inv" in dir(layer):
|
||||
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
||||
del layer.weight_scale_inv
|
||||
|
||||
if layer.weight_block_size is None:
|
||||
group_size = -1
|
||||
else:
|
||||
group_size = layer.weight_block_size[1]
|
||||
|
||||
# marlin kernel only supports channel-wise and group-wise quantization
|
||||
# we need to convert the scales
|
||||
if layer.weight_block_size is None:
|
||||
if scales.nelement() == 1:
|
||||
# tensor-wise quantization -> channel-wise quantization
|
||||
# (1, 1) =>(repeat)=> (1, size_n)
|
||||
scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
|
||||
elif scales.nelement() > 1 and scales.nelement() != part_size_n:
|
||||
assert part_size_n % scales.nelement() == 0
|
||||
s_size = scales.nelement()
|
||||
# tensor-wise quantization (for gate-up proj)
|
||||
# -> channel-wise quantization
|
||||
# (1, s_size) =>(repeat)=> (1, size_n)
|
||||
scales = scales.view(1, s_size)
|
||||
scales = scales.repeat_interleave(part_size_n // s_size, 1)
|
||||
else:
|
||||
# channel-wise quantization
|
||||
# (1, size_n)
|
||||
scales = scales.view(1, part_size_n)
|
||||
else:
|
||||
# block-wise quantization -> group-wise quantization
|
||||
# (size_k // block_size[1], ceil(size_n / block_size[0]))
|
||||
# =>(repeat)=> (size_k // block_size[1], size_n)
|
||||
block_n = layer.weight_block_size[0]
|
||||
scales = scales.T.repeat_interleave(block_n, 1)
|
||||
# size_n may not be divisible by block_size[0]
|
||||
scales = scales[:, :part_size_n]
|
||||
|
||||
marlin_scales = marlin_permute_scales(s=scales,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=-1)
|
||||
group_size=group_size)
|
||||
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
||||
|
||||
|
||||
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
||||
def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
                                     size_k_first: bool = True) -> None:
    """Convert an FP8 MoE layer in place to the marlin weight/scale layout.

    Reads ``num_experts``, ``hidden_size``,
    ``intermediate_size_per_partition``, ``orig_dtype`` and
    ``weight_block_size`` from ``layer``. It then:

    * allocates ``layer.workspace``;
    * replaces ``w13_weight`` and ``w2_weight`` with per-expert
      marlin-repacked weights;
    * removes ``<prefix>_weight_scale`` (or ``<prefix>_weight_scale_inv``)
      and stores the converted, permuted scales as
      ``<prefix>_weight_scale``.

    Args:
        layer: Module holding the expert weights and scales.
        size_k_first: If True, each expert weight must have shape
            ``(size_k, size_n)``; otherwise ``(size_n, size_k)``.

    Raises:
        AttributeError: if neither ``<prefix>_weight_scale`` nor
            ``<prefix>_weight_scale_inv`` exists on ``layer``.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition
    # (size_n, size_k) used for each weight group; shared by the weight
    # and the scale passes below.
    gemm_shapes = {"w13": (n * 2, k), "w2": (k, n)}

    # WORKSPACE
    device = layer.w13_weight.device
    layer.workspace = marlin_make_workspace_new(device, 4)
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format, one expert at a time.
    for prefix, (size_n, size_k) in gemm_shapes.items():
        name = prefix + "_weight"
        weight = getattr(layer, name)

        if size_k_first:
            assert weight.shape == (e, size_k, size_n)
        else:
            assert weight.shape == (e, size_n, size_k)

        repacked = []
        for i in range(e):
            qweight = pack_fp8_to_int32(weight[i], size_k_first)
            if not size_k_first:
                qweight = qweight.T.contiguous()

            repacked.append(
                ops.gptq_marlin_repack(b_q_weight=qweight,
                                       perm=perm,
                                       size_k=size_k,
                                       size_n=size_n,
                                       num_bits=8))

        weight = torch.nn.Parameter(torch.stack(repacked, 0),
                                    requires_grad=False)
        setattr(layer, name, weight)

    # WEIGHT SCALES
    # Permute scales
    block_size = layer.weight_block_size
    group_size = -1 if block_size is None else block_size[1]

    for prefix, (size_n, size_k) in gemm_shapes.items():
        scale_name = prefix + "_weight_scale"

        # Take whichever scale attribute exists and drop it from the layer.
        # Fail loudly when neither exists so that a stale `scales` from the
        # previous iteration can never be reused by accident.
        for candidate in (scale_name, scale_name + "_inv"):
            if hasattr(layer, candidate):
                scales = getattr(layer, candidate).to(layer.orig_dtype)
                delattr(layer, candidate)
                break
        else:
            raise AttributeError(
                f"{type(layer).__name__} has neither '{scale_name}' nor "
                f"'{scale_name}_inv'")

        # The marlin kernel only supports channel-wise and group-wise
        # quantization, so the scales need to be converted.
        if block_size is None:
            num_scales = scales.nelement()
            if num_scales == e:
                # tensor-wise quantization -> channel-wise quantization
                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
            elif num_scales > e and num_scales != e * size_n:
                assert (e * size_n) % num_scales == 0
                s_size = num_scales // e
                # tensor-wise quantization (for gate-up proj)
                #     -> channel-wise quantization
                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, s_size)
                scales = scales.repeat_interleave(size_n // s_size, 2)
            else:
                # channel-wise quantization
                # (e, 1, size_n)
                scales = scales.view(e, 1, size_n)
        else:
            # block-wise quantization -> group-wise quantization
            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
            #  =>(repeat)=> (e, size_k // block_size[1], size_n)
            block_n = block_size[0]
            scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
            # size_n may not be divisible by block_size[0]
            scales = scales[..., :size_n].contiguous()

        permuted = [
            marlin_permute_scales(s=scales[i],
                                  size_k=size_k,
                                  size_n=size_n,
                                  group_size=group_size) for i in range(e)
        ]

        scales = torch.nn.Parameter(torch.stack(permuted, 0),
                                    requires_grad=False)
        setattr(layer, scale_name, scales)
|
||||
|
||||
|
||||
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
|
||||
size_k_first: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Repack FP8 weights to gptq format (packed int32 elements)
|
||||
"""
|
||||
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
||||
assert fp8_tensor.shape[0] % 4 == 0
|
||||
assert fp8_tensor.ndim == 2
|
||||
|
||||
# Reshape to prepare for packing
|
||||
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
||||
fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
|
||||
fp8_tensor = fp8_tensor.contiguous()
|
||||
# fp8_tensor is contiguous and have shape (N, K) now
|
||||
# with `.view(torch.int32)`, it become (N, K // 4)
|
||||
int32_tensor = fp8_tensor.view(torch.int32)
|
||||
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
|
||||
|
||||
# Convert fp8 to uint8 (byte) representation
|
||||
byte_tensor = reshaped.view(torch.uint8)
|
||||
|
||||
# Pack 4 uint8 values into one int32
|
||||
packed = (byte_tensor[:, 0].to(torch.int32) |
|
||||
(byte_tensor[:, 1].to(torch.int32) << 8) |
|
||||
(byte_tensor[:, 2].to(torch.int32) << 16) |
|
||||
(byte_tensor[:, 3].to(torch.int32) << 24))
|
||||
def marlin_quant_fp8_torch(weight, group_size):
    """
    Quantize a 2-D weight to FP8 (e4m3fn) and convert it to marlin format.

    Args:
        weight: Tensor of shape (size_n, size_k).
        group_size: Number of consecutive elements along K that share one
            scale. ``-1`` means one scale per output row (channel-wise).

    Returns:
        A tuple ``(weight_ref, marlin_qweight, marlin_scales)`` where
        ``weight_ref`` is the dequantized weight transposed to
        (size_k, size_n), ``marlin_qweight`` is the result of
        ``ops.gptq_marlin_repack`` on the packed FP8 weight, and
        ``marlin_scales`` is the result of ``marlin_permute_scales``.
    """
    size_n, size_k = weight.shape
    device = weight.device

    # Channel-wise quantization is group-wise quantization with a single
    # group spanning the whole K dimension.
    effective_group_size = size_k if group_size == -1 else group_size

    # Scale each group so its largest magnitude maps to the largest finite
    # float8_e4m3fn value (448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    grouped = weight.view(size_n, -1, effective_group_size)
    scales = grouped.abs().max(-1)[0] / fp8_max
    repeated_scales = scales.repeat_interleave(effective_group_size, 1)

    fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
    # Dequantized reference in the original dtype, for comparing results.
    weight_ref = fp8_weight.to(weight.dtype) * repeated_scales

    # (size_n, size_k) fp8 -> (size_n, size_k // 4) int32 -> (size_k // 4, size_n)
    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_weight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    # Scales are passed with the group dimension first: (num_groups, size_n).
    marlin_scales = marlin_permute_scales(s=scales.T,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=group_size)

    return weight_ref.T, marlin_qweight, marlin_scales
|
||||
|
||||
Reference in New Issue
Block a user