[Kernel] some optimizations for dense marlin and moe marlin (#16850)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Author: Jinzhen Lin
Date: 2025-05-06 00:39:30 +08:00
Committed by: GitHub
Parent: f62cad6431
Commit: 1d0c9d6b2d
26 changed files with 3512 additions and 3268 deletions


@@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, check_moe_marlin_supports_layer,
-    marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
-    verify_marlin_supported)
+    marlin_make_workspace_new, marlin_moe_permute_scales,
+    marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: GPTQMarlinConfig) -> None:
         self.quant_config = quant_config
+        if self.quant_config.quant_type.size_bits == 4:
+            self.quant_type = scalar_types.uint4b8
+        elif self.quant_config.quant_type.size_bits == 8:
+            self.quant_type = scalar_types.uint8b128
+        else:
+            raise ValueError(
+                "GPTQMarlinMoEMethod only supports int4 and int8 now.")
 
     def create_weights(
         self,
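
Caching the scalar type in __init__ lets apply() pass a single quant_type_id instead of re-deriving the bit width on every forward call. Below is a minimal sketch of the same mapping, assuming vllm.scalar_type exposes scalar_types.uint4b8 and uint8b128 as used in the hunk above; the standalone helper name gptq_bits_to_marlin_type is hypothetical and not part of the commit.

from vllm.scalar_type import ScalarType, scalar_types

def gptq_bits_to_marlin_type(size_bits: int) -> ScalarType:
    # Mirrors the __init__ change: GPTQ Marlin MoE supports 4- and 8-bit only.
    if size_bits == 4:
        return scalar_types.uint4b8    # unsigned 4-bit values with a bias of 8
    if size_bits == 8:
        return scalar_types.uint8b128  # unsigned 8-bit values with a bias of 128
    raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
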
@@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
 
         device = layer.w13_qweight.device
-        sms = torch.cuda.get_device_properties(device).multi_processor_count
-        layer.workspace = torch.zeros((sms * 4, ),
-                                      dtype=torch.int,
-                                      device=device,
-                                      requires_grad=False)
+        layer.workspace = marlin_make_workspace_new(device, 4)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
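
For reference, the deleted inline code sized the workspace from the device's SM count; the new marlin_make_workspace_new helper from marlin_utils centralizes that logic, and its actual sizing may differ from this. The sketch below only reproduces the removed lines, under a hypothetical name make_moe_workspace.

import torch

def make_moe_workspace(device: torch.device, factor: int = 4) -> torch.Tensor:
    # `factor` int slots per SM, zero-initialized, no gradients needed.
    sms = torch.cuda.get_device_properties(device).multi_processor_count
    return torch.zeros((sms * factor, ),
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)
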
@@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             router_logits,
             topk_weights,
             topk_ids,
+            quant_type_id=self.quant_type.id,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             g_idx1=layer.w13_g_idx,
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-            num_bits=self.quant_config.quant_type.size_bits,
             workspace=layer.workspace,
             is_k_full=self.is_k_full)
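
Passing quant_type_id (a plain integer) instead of num_bits carries the full scalar-type description across the call boundary. A small sketch of the round trip, assuming ScalarType.from_id is available in vllm.scalar_type; exact kernel-side helper names may differ.

from vllm.scalar_type import ScalarType, scalar_types

quant_type_id = scalar_types.uint4b8.id         # what apply() now forwards
quant_type = ScalarType.from_id(quant_type_id)  # recovered on the kernel side
assert quant_type.size_bits == 4                # num_bits is implied by the type
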