[Kernel] moe wna16 marlin kernel (#14447)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <michael@neuralmagic.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||
check_marlin_supports_layer, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
|
||||
moe_awq_to_marlin_zero_points, verify_marlin_supported,
|
||||
verify_marlin_supports_shape)
|
||||
check_marlin_supports_layer, check_moe_marlin_supports_layer,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
|
||||
marlin_permute_scales, moe_awq_to_marlin_zero_points,
|
||||
verify_marlin_supported, verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if layer.local_num_experts > 32:
|
||||
# For MoEs with many experts the moe_wna16 kernel is faster
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_one(
|
||||
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return AWQMoEMethod(self)
|
||||
return AWQMoEMethod(self)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
layer.workspace = torch.zeros((sms * 4, ),
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_qweight.shape[0]
|
||||
device = layer.w13_qweight.device
|
||||
@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if expert_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert Parallelism is not supported for "
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
|
||||
@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, marlin_moe_permute_scales,
|
||||
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
|
||||
check_marlin_supported, check_moe_marlin_supports_layer,
|
||||
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
|
||||
verify_marlin_supported)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
if layer.local_num_experts > 32:
|
||||
# For MoEs with many experts the moe_wna16 kernel is faster
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_one(
|
||||
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
return get_linear_quant_method(self, layer, prefix,
|
||||
GPTQMarlinLinearMethod)
|
||||
|
||||
@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
torch.empty(num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.half),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
torch.empty(num_experts,
|
||||
scales_size2,
|
||||
hidden_size,
|
||||
dtype=torch.half),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
w2_g_idx_sort_indices)
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
layer.workspace = torch.zeros((sms * 4, ),
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# Process act_order
|
||||
@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
x = x.half()
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
is_k_full=self.is_k_full).to(orig_dtype)
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full)
|
||||
|
||||
@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
|
||||
group_size=group_size)[0]
|
||||
|
||||
|
||||
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
|
||||
-> bool:
|
||||
hidden_size = layer.hidden_size
|
||||
intermediate_size_per_partition = layer.intermediate_size_per_partition
|
||||
|
||||
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
|
||||
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
|
||||
# moe marlin requires n % 128 == 0 and k % 64 == 0
|
||||
return hidden_size % 128 == 0 and \
|
||||
intermediate_size_per_partition % max(64, group_size) == 0 and \
|
||||
group_size in [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def marlin_make_workspace(output_size_per_partition: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
max_workspace_size = (output_size_per_partition //
|
||||
|
||||
Reference in New Issue
Block a user