[Kernel] some optimizations for dense marlin and moe marlin (#16850)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
Jinzhen Lin
2025-05-06 00:39:30 +08:00
committed by GitHub
parent f62cad6431
commit 1d0c9d6b2d
26 changed files with 3512 additions and 3268 deletions

View File

@@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
marlin_make_empty_g_idx, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
@@ -267,8 +268,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
layer.workspace = marlin_make_workspace_new(device)
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
@@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -396,11 +399,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
@@ -511,10 +510,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
)
workspace=layer.workspace)