[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Jinzhen Lin
2025-11-29 23:19:33 +08:00
committed by GitHub
parent fa59fe417f
commit 1656ad3704
46 changed files with 4371 additions and 2240 deletions

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_in
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_moe_intermediate_size,
maybe_warn_marlin_atomic_add,
marlin_quant_input,
)
from vllm.scalar_type import ScalarType, scalar_types
@@ -65,6 +65,8 @@ def _fused_marlin_moe(
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
@@ -77,6 +79,7 @@ def _fused_marlin_moe(
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
is_k_full: bool = True,
) -> torch.Tensor:
assert hidden_states.ndim == 2
@@ -106,18 +109,22 @@ def _fused_marlin_moe(
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = (
hidden_states.dtype == torch.half
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
)
a_scales1 = None
gate_up_input = hidden_states
if input_dtype == torch.int8:
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
if input_global_scale1 is not None:
a_scales1 = a_scales1 * input_global_scale1
elif input_dtype == torch.float8_e4m3fn:
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states,
gate_up_input,
intermediate_cache1,
w1,
bias1,
w1_scale,
a_scales1,
global_scale1,
w1_zeros,
g_idx1,
@@ -136,7 +143,7 @@ def _fused_marlin_moe(
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
@@ -151,12 +158,25 @@ def _fused_marlin_moe(
if expert_map is not None:
output.zero_()
a_scales2 = None
if input_dtype == torch.int8:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
if input_global_scale2 is not None:
a_scales2 = a_scales2 * input_global_scale2
elif input_dtype == torch.float8_e4m3fn:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
output = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
output,
w2,
bias2,
w2_scale,
a_scales2,
global_scale2,
w2_zeros,
g_idx2,
@@ -175,7 +195,7 @@ def _fused_marlin_moe(
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
@@ -203,6 +223,8 @@ def fused_marlin_moe(
] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
@@ -216,6 +238,7 @@ def fused_marlin_moe(
intermediate_cache2: torch.Tensor | None = None,
is_k_full: bool = True,
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
inplace: bool = False,
) -> torch.Tensor:
"""
@@ -287,6 +310,9 @@ def fused_marlin_moe(
if M * topk / E / block_size_m < 0.9:
break
if input_dtype is not None and input_dtype.itemsize == 1:
block_size_m = max(block_size_m, 16)
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
@@ -313,6 +339,8 @@ def fused_marlin_moe(
num_tokens_post_padded=num_tokens_post_padded,
activation=activation,
activation_func=activation_func,
input_global_scale1=input_global_scale1,
input_global_scale2=input_global_scale2,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
@@ -325,6 +353,7 @@ def fused_marlin_moe(
intermediate_cache13=intermediate_cache13,
intermediate_cache2=intermediate_cache2,
output=None,
input_dtype=input_dtype,
is_k_full=is_k_full,
).view(-1, topk, K)

View File

@@ -266,7 +266,7 @@ class AutoRoundConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMarlinLinearMethod,
AWQMoEMethod,
AWQMarlinMoEMethod,
)
quant_args_marlin = AWQMarlinConfig(
@@ -291,7 +291,7 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin, layer.moe_config)
return AWQMarlinMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
config = {

View File

@@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig):
return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):
# Lazy import to avoid circular import.
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod
from .moe_wna16 import MoeWNA16Config
from .utils.marlin_utils import check_moe_marlin_supports_layer
@@ -136,7 +136,7 @@ class AWQConfig(QuantizationConfig):
awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict
)
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config)
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):

View File

@@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_marlin_supports_layer,
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_moe_permute_scales,
@@ -69,7 +71,6 @@ class AWQMarlinConfig(QuantizationConfig):
# num_bits -> type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
def __init__(
@@ -193,7 +194,9 @@ class AWQMarlinConfig(QuantizationConfig):
return AWQConfig.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMarlinLinearMethod(self)
quant_method = AWQMarlinLinearMethod(self)
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
@@ -211,7 +214,9 @@ class AWQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMoEMethod(self, layer.moe_config)
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
return None
@classmethod
@@ -270,6 +275,8 @@ class AWQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config
self.quant_type = scalar_types.uint4
self.input_dtype = None
def create_weights(
self,
@@ -312,6 +319,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
)
num_groups = input_size_per_partition // group_size
layer.num_groups = num_groups
qzeros = PackedvLLMParameter(
data=torch.empty(
@@ -358,12 +366,19 @@ class AWQMarlinLinearMethod(LinearMethodBase):
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
layer.scales.data = layer.scales.data * 512
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "qweight", marlin_qweight)
@@ -373,7 +388,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups > 1:
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
marlin_scales
)
layer.register_parameter(
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
@@ -382,6 +406,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "qzeros", marlin_zp)
@@ -409,11 +434,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=self.input_dtype,
)
class AWQMoEMethod(FusedMoEMethodBase):
class AWQMarlinMoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: AWQMarlinConfig,
@@ -422,8 +449,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
raise ValueError("AWQMarlinMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
self.input_dtype = None
self.use_marlin = True
def create_weights(
@@ -435,6 +463,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.input_dtype = self.input_dtype
extra_weight_attrs.update(
{
"is_transposed": True,
@@ -468,6 +497,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
@@ -522,6 +553,21 @@ class AWQMoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(
layer.w13_qweight.view(-1, layer.w13_qweight.size(2)),
layer.w13_qzeros.view(-1, layer.w13_qzeros.size(2)),
inplace=True,
)
ops.marlin_int4_fp8_preprocess(
layer.w2_qweight.view(-1, layer.w2_qweight.size(2)),
layer.w2_qzeros.view(-1, layer.w2_qzeros.size(2)),
inplace=True,
)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
@@ -538,6 +584,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
@@ -547,6 +594,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
@@ -556,7 +604,16 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
@@ -565,7 +622,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points(
@@ -573,6 +640,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
@@ -581,6 +649,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
@@ -636,6 +705,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
@@ -643,4 +714,5 @@ class AWQMoEMethod(FusedMoEMethodBase):
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
input_dtype=self.input_dtype,
)

View File

@@ -157,7 +157,9 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
return CompressedTensorsMoEMethod.get_moe_method(
self, layer, layer_name=prefix
)
return None
def _add_fused_moe_to_target_scheme_map(self):
@@ -547,6 +549,7 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
format: str | None = None,
layer_name: str | None = None,
) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
@@ -585,6 +588,7 @@ class CompressedTensorsConfig(QuantizationConfig):
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder,
layer_name=layer_name,
)
act_quant_format = is_activation_quantization_format(format)
@@ -724,7 +728,10 @@ class CompressedTensorsConfig(QuantizationConfig):
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant, input_quant=input_quant, format=format
weight_quant=weight_quant,
input_quant=input_quant,
format=format,
layer_name=layer_name,
)
# Raise error if device does not support the scheme

View File

@@ -64,6 +64,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
@@ -101,7 +103,7 @@ __all__ = [
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4Nvfp4MoeMethod",
"CompressedTensorsW4A4Nvfp4MoEMethod",
"CompressedTensorsW4A8Int8MoEMethod",
]
@@ -111,13 +113,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
prefix: str,
layer_name: str,
) -> "CompressedTensorsMoEMethod":
# FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map()
unfused_names = [
prefix + proj_name
layer_name + proj_name
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
# TODO: refactor this to use expert_mapping and check all layer numbers
@@ -158,32 +160,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic."
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsWNA16MoEMethod(
quant_config, layer.moe_config, layer_name
)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config)
return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsW8A8Fp8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsW8A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsW4A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
)
class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig):
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
)
@@ -194,17 +204,21 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4Nvfp4MoeMethod."
" for CompressedTensorsW4A4Nvfp4MoEMethod."
)
elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.")
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.")
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
def create_weights(
self,
@@ -354,7 +368,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
)
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
return
# w13
if (
@@ -538,7 +552,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
)
return flashinfer_trtllm_fp4_moe(
@@ -576,6 +590,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
@@ -610,7 +625,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
assert expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoeMethod."
"CompressedTensorsW4A4Nvfp4MoEMethod."
)
assert self.moe_quant_config is not None
@@ -637,6 +652,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
@@ -690,6 +706,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
or self.is_fp8_w8a8_sm100
)
self.disable_expert_map = False
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
def create_weights(
self,
@@ -931,7 +951,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
@@ -1144,6 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
@@ -1240,6 +1263,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
@@ -1392,6 +1416,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
@@ -1403,6 +1428,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
self.layer_name = layer_name
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
assert config.symmetric, "Only symmetric quantization is supported for MoE"
if not (
@@ -1477,6 +1504,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
w13_scale = torch.nn.Parameter(
torch.ones(
num_experts,
@@ -1560,6 +1590,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device
is_a_8bit = (
self.marlin_input_dtype is not None
and self.marlin_input_dtype.itemsize == 1
)
if self.marlin_input_dtype == torch.float8_e4m3fn:
# NOTE: for non-zp quantization format only
ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True)
layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512
layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512
# when running models with grouped act order,
# resort to g_idx values provided in checkpoint
@@ -1610,31 +1651,54 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[1] * self.packed_factor,
layer.w13_weight_packed.shape[2],
self.num_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
layer.w2_weight_packed.shape[2],
self.num_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_weight_scale,
size_k=layer.w13_weight_packed.shape[2],
size_n=layer.w13_weight_scale.shape[2],
group_size=self.group_size,
is_a_8bit=is_a_8bit,
)
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_weight_scale,
size_k=layer.w2_weight_scale.shape[1]
* (self.group_size if self.group_size != -1 else self.packed_factor),
size_n=layer.w2_weight_scale.shape[2],
group_size=self.group_size,
is_a_8bit=is_a_8bit,
)
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
layer.workspace = marlin_make_workspace_new(device, 4)
@@ -1729,6 +1793,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
@@ -1738,6 +1804,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full,
)
@@ -1747,6 +1814,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
@@ -1999,6 +2067,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.has_bias = self.moe.has_bias

View File

@@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.parameter import (
@@ -45,12 +49,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size: int | None = None,
symmetric: bool | None = True,
actorder: ActivationOrdering | None = None,
layer_name: str | None = None,
):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
self.layer_name = layer_name
if self.group_size == -1 and self.strategy != "channel":
raise ValueError(
@@ -108,6 +114,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
if isinstance(kernel_type, MarlinLinearKernel):
input_dtype = get_marlin_input_dtype(self.layer_name)
if input_dtype is not None:
mp_linear_kernel_config.act_type = input_dtype
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = input_size != input_size_per_partition

View File

@@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
@@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
quant_method = Fp8LinearMethod(self)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix=prefix,
@@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Fp8MoEMethod(self, layer)
moe_quant_method = Fp8MoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
@@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
@@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase):
)
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer, size_k_first)
prepare_fp8_layer_for_marlin(
layer, size_k_first, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.input_scale
return
@@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase):
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
input_dtype=self.marlin_input_dtype,
bias=bias,
)
@@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.block_quant, layer.moe_parallel_config
)
self.marlin_input_dtype = None
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
@@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight.data = w13_weight.data
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
@@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:

View File

@@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new,
marlin_moe_permute_scales,
marlin_permute_bias,
@@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix
)
return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod)
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
moe_quant_method = get_moe_quant_method(
self, layer, prefix, GPTQMarlinMoEMethod
)
if moe_quant_method is None:
return None
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
quant_method = get_linear_quant_method(
self, layer, prefix, GPTQMarlinLinearMethod
)
if quant_method is None:
return None
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
@@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
self.input_dtype = None
self.quant_type = self.quant_config.quant_type
# Verify supported on platform.
verify_marlin_supported(
@@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
input_dtype = self.input_dtype
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
@@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition,
),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
act_type=params_dtype if input_dtype is None else input_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act,
@@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_type = scalar_types.uint8b128
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
self.use_marlin = True
def create_weights(
@@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.input_dtype = self.input_dtype
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
self.is_k_full = (not self.quant_config.desc_act) or (
@@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
layer.num_groups_w13 = scales_size13
layer.num_groups_w2 = scales_size2
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
@@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
@@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
@@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
@@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
@@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
@@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
@@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
)

View File

@@ -351,6 +351,7 @@ class HQQMarlinMethod(LinearMethodBase):
bias,
scales,
None,
None,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,

View File

@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear,
check_marlin_supports_shape,
marlin_act_int8_process_scales,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
@@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
@@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
if is_a_8bit:
assert c.weight_type == scalar_types.uint4b8, (
"W8A8 is not supported by marlin kernel."
)
if c.act_type == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
@@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
)
return x
@@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
is_a_8bit=is_a_8bit,
)
if c.group_size == -1:
num_groups = 1
else:
num_groups = c.partition_weight_shape[0] // c.group_size
if c.act_type == torch.int8 and num_groups > 1:
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
layer.register_parameter(
"input_global_scale",
torch.nn.Parameter(input_global_scale, requires_grad=False),
)
else:
layer.input_global_scale = None
return x
if c.has_g_idx:
@@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
),
)
else:
@@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel):
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
@@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel):
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=c.act_type,
)

View File

@@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
@@ -170,9 +173,15 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# now, the layer is quantized, handle it here
if isinstance(layer, LinearBase):
return self.LinearMethodCls(self)
quant_method = self.LinearMethodCls(self)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
return self.FusedMoEMethodCls(quant_config=self, layer=layer)
quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
return None
@@ -898,6 +907,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.marlin_input_dtype = None
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
@@ -1065,6 +1075,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
input_dtype=self.marlin_input_dtype,
)
output_dtype = x.dtype
@@ -1124,6 +1135,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.marlin_input_dtype = None
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
@@ -1517,7 +1529,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
)
elif self.allow_flashinfer:

View File

@@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
@@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig):
if current_platform.is_xpu():
return IpexMxfp4MoEMethod(layer.moe_config)
else:
return Mxfp4MoEMethod(layer.moe_config)
quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
logger.debug_once(
@@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
@@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer)
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
@@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
)
assert _can_support_mxfp4(

View File

@@ -9,6 +9,11 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -286,10 +291,10 @@ def get_scale_perms():
def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
if group_size < size_k and group_size != -1 and not is_a_8bit:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
@@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
return s.reshape(*origin_shape).contiguous()
def marlin_act_int8_process_scales(s: torch.Tensor):
a_scales_scale_factor = 1 / 4096 * s.max().float()
s = s / s.max() * 4096
s = s.round().to(torch.int16).view(s.dtype)
return s, a_scales_scale_factor
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
):
num_experts = s.shape[0]
output = torch.empty(
@@ -319,12 +328,12 @@ def marlin_moe_permute_scales(
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit)
return output
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False
) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
@@ -339,7 +348,8 @@ def marlin_zero_points(
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
if not is_a_8bit:
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
@@ -347,7 +357,11 @@ def marlin_zero_points(
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
@@ -366,12 +380,16 @@ def awq_to_marlin_zero_points(
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit)
return marlin_zp
def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
):
num_experts = q_zp_packed.shape[0]
output = torch.empty(
@@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points(
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
output[e] = awq_to_marlin_zero_points(
q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit
)
return output
@@ -432,6 +452,48 @@ def should_use_atomic_add_reduce(
return True
_quant_fp8_method: QuantFP8 | None = None
def get__quant_fp8_method() -> QuantFP8:
global _quant_fp8_method
if _quant_fp8_method is None:
_quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN)
return _quant_fp8_method
def get_marlin_input_dtype(prefix):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
return torch.int8
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
)
_ = get__quant_fp8_method()
return torch.float8_e4m3fn
else:
return
def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype):
x = x.reshape(-1, x.shape[-1])
if quant_dtype == torch.int8:
return per_token_quant_int8(x)
elif quant_dtype == torch.float8_e4m3fn:
return get__quant_fp8_method()(x)
else:
raise ValueError(f"unsupported quant_dtype {quant_dtype}")
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -444,8 +506,10 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
@@ -458,12 +522,27 @@ def apply_gptq_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
assert wtype == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert wtype == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,
@@ -493,8 +572,10 @@ def apply_awq_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
@@ -507,12 +588,20 @@ def apply_awq_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,
@@ -538,8 +627,10 @@ def apply_rtn_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
@@ -552,12 +643,20 @@ def apply_rtn_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
None,
None,

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.platforms import current_platform
@@ -37,12 +38,6 @@ def nvfp4_marlin_process_scales(marlin_scales):
# convert to half first, we would convert to fp8 later
marlin_scales = marlin_scales.to(torch.half)
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
@@ -62,18 +57,20 @@ def nvfp4_marlin_process_scales(marlin_scales):
return marlin_scales
def mxfp4_marlin_process_scales(marlin_scales):
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
)
if input_dtype is None or input_dtype.itemsize == 2:
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
)
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
if input_dtype == torch.float8_e4m3fn:
marlin_scales = marlin_scales.view(torch.uint8)
assert marlin_scales.max() <= 249
# exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6
marlin_scales = marlin_scales + 6
marlin_scales = marlin_scales.view(torch.float8_e8m0fnu)
return marlin_scales
@@ -99,6 +96,7 @@ def apply_fp4_marlin_linear(
size_n: int,
size_k: int,
bias: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
# For GPUs that lack FP4 hardware support, we can leverage the
@@ -111,12 +109,24 @@ def apply_fp4_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
)
inputs = reshaped_x
a_scales = None
is_nvfp4 = weight_scale_2 is not None
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm(
a=reshaped_x,
a=inputs,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
a_scales=a_scales,
global_scale=weight_scale_2,
b_zeros=None,
g_idx=None,
@@ -133,7 +143,9 @@ def apply_fp4_marlin_linear(
return output.reshape(out_shape)
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "
@@ -160,12 +172,14 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
perm = torch.empty(0, dtype=torch.int, device=device)
qweight = layer.weight.view(torch.int32).T.contiguous()
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight,
perm=perm,
size_k=part_size_k,
size_n=part_size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
@@ -178,7 +192,11 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales(
s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size
s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
if is_nvfp4:
@@ -189,7 +207,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False)
else:
weight_scale = mxfp4_marlin_process_scales(weight_scale)
weight_scale = mxfp4_marlin_process_scales(
weight_scale, input_dtype=input_dtype
)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None:
@@ -200,7 +220,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
return
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "
@@ -220,6 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
param_dtype = layer.params_dtype
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
# WEIGHT
# Repack weights to marlin format
@@ -237,7 +260,12 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4
b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
tensor_list.append(marlin_qweight)
@@ -266,12 +294,18 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
scale = scales[i].T
marlin_scales = marlin_permute_scales(
s=scale, size_k=size_k, size_n=size_n, group_size=group_size
s=scale,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
marlin_scales = mxfp4_marlin_process_scales(
marlin_scales, input_dtype=input_dtype
)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
@@ -301,7 +335,10 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
setattr(layer, name, bias)
def rand_marlin_weight_nvfp4_like(weight, group_size):
def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
assert not is_a_8bit, "NVFP4 weight + INT8/FP8 activation is not supported."
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
@@ -337,10 +374,15 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
@@ -349,14 +391,20 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
def rand_marlin_weight_mxfp4_like(weight, group_size):
def rand_marlin_weight_mxfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn, (
"MXFP4 weight + INT8 activation is not supported."
)
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
scales = torch.randint(
100,
125,
110,
120,
(size_n, size_k // group_size),
dtype=torch.uint8,
device=weight.device,
@@ -380,18 +428,25 @@ def rand_marlin_weight_mxfp4_like(weight, group_size):
).view(size_n, size_k)
weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype)
perm = torch.empty(0, dtype=torch.int, device=device)
fp4_weight = fp4_weight.view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
perm=torch.empty(0, dtype=torch.int, device=device),
b_q_weight=fp4_weight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
marlin_scales = mxfp4_marlin_process_scales(marlin_scales, input_dtype=input_dtype)
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.platforms import current_platform
@@ -45,6 +46,7 @@ def apply_fp8_marlin_linear(
size_n: int,
size_k: int,
bias: torch.Tensor | None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the
@@ -57,12 +59,21 @@ def apply_fp8_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
)
inputs = reshaped_x
a_scales = None
if input_dtype is not None and input_dtype.itemsize == 1:
if input_dtype != torch.float8_e4m3fn:
raise RuntimeError("FP8 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm(
a=reshaped_x,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
a_scales=a_scales,
global_scale=None,
b_zeros=None,
g_idx=None,
@@ -80,7 +91,9 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True
layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
@@ -162,7 +175,8 @@ def prepare_fp8_layer_for_marlin(
marlin_scales = marlin_permute_scales(
s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None:
@@ -172,7 +186,9 @@ def prepare_fp8_layer_for_marlin(
def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True
layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
@@ -278,7 +294,8 @@ def prepare_moe_fp8_layer_for_marlin(
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = fp8_fused_exponent_bias_into_scales(scales)
if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
@@ -318,7 +335,11 @@ def pack_fp8_to_int32(
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
def marlin_quant_fp8_torch(weight, group_size):
def marlin_quant_fp8_torch(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn
size_n, size_k = weight.shape
device = weight.device
@@ -334,16 +355,22 @@ def marlin_quant_fp8_torch(weight, group_size):
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
perm = torch.empty(0, dtype=torch.int, device=device)
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_weight,
perm=torch.empty(0, dtype=torch.int, device=device),
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=8,
is_a_8bit=is_a_8bit,
)
marlin_scales = marlin_permute_scales(
s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)

View File

@@ -5,7 +5,8 @@
import numpy as np
import torch
from vllm.scalar_type import ScalarType
from vllm import _custom_ops as ops
from vllm.scalar_type import ScalarType, scalar_types
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
from .quant_utils import (
@@ -29,13 +30,19 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
def marlin_permute_weights(
q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False
):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
if is_a_8bit:
# Permute weights to 32x32 marlin tiles
q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile))
else:
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))
@@ -44,9 +51,9 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit)
# Pack
pack_factor = get_pack_factor(num_bits)
@@ -63,28 +70,53 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
return q_packed
def get_weight_perm(num_bits: int):
def get_weight_perm(num_bits: int, is_a_8bit: bool = False):
perm_list: list[int] = []
for i in range(32):
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
if is_a_8bit:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
4 * (i % 4),
4 * (i % 4) + 1,
4 * (i % 4) + 2,
4 * (i % 4) + 3,
4 * (i % 4 + 4),
4 * (i % 4 + 4) + 1,
4 * (i % 4 + 4) + 2,
4 * (i % 4 + 4) + 3,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(2):
perm_list.extend([p + 512 * j for p in perm1])
else:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
else:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 1, 2, 3])
else:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
@@ -99,7 +131,10 @@ def marlin_quantize(
group_size: int,
act_order: bool,
test_perm: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape
num_bits = quant_type.size_bits
@@ -120,9 +155,15 @@ def marlin_quantize(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
weight_perm = get_weight_perm(num_bits, is_a_8bit)
marlin_q_w = marlin_weights(
q_w, size_k, size_n, num_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4b8:
ops.marlin_int4_fp8_preprocess(marlin_q_w, inplace=True)
marlin_s = marlin_s * 512
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
@@ -132,7 +173,13 @@ def marlin_quantize(
return res_list
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
def awq_marlin_quantize(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape
# Normalize group_size
@@ -147,11 +194,22 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int
# Quantize with zp
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4:
repeated_zp = zp.repeat_interleave(group_size, 0)
q_w_old = q_w
q_w = q_w_old - repeated_zp
q_w[q_w < 0] = 15 - q_w_old[q_w < 0]
s = s * 512
# Reformat to marlin
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
marlin_zp = marlin_zero_points(
zp, num_groups, size_n, quant_type.size_bits, is_a_8bit=is_a_8bit
)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]