[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)
Co-authored-by: Dipika <dipikasikka1@gmail.com> Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
|
||||
marlin_permute_scales, moe_awq_to_marlin_zero_points,
|
||||
verify_marlin_supported, verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
@@ -35,12 +40,13 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
self.group_size = group_size
|
||||
self.has_zp = has_zp
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.weight_bits = weight_bits
|
||||
|
||||
if weight_bits not in self.TYPE_MAP:
|
||||
raise ValueError(f"Unsupported num_bits = {weight_bits}. "
|
||||
if self.weight_bits not in self.TYPE_MAP:
|
||||
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
|
||||
f"Supported num_bits = {self.TYPE_MAP.keys()}")
|
||||
|
||||
self.quant_type = self.TYPE_MAP[weight_bits]
|
||||
self.quant_type = self.TYPE_MAP[self.weight_bits]
|
||||
|
||||
verify_marlin_supported(self.quant_type,
|
||||
group_size=self.group_size,
|
||||
@@ -98,10 +104,12 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["AWQMarlinLinearMethod"]:
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return AWQMoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
@@ -271,4 +279,182 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
quant_type=self.quant_config.quant_type,
|
||||
output_size_per_partition=layer.output_size_per_partition,
|
||||
input_size_per_partition=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
bias=bias)
|
||||
|
||||
|
||||
class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
extra_weight_attrs.update({
|
||||
"is_transposed":
|
||||
True,
|
||||
"quant_method":
|
||||
FusedMoeWeightScaleSupported.GROUP.value,
|
||||
})
|
||||
|
||||
w13_qweight = Parameter(torch.empty(num_experts,
|
||||
hidden_size,
|
||||
2 * intermediate_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
w2_qweight = Parameter(torch.empty(num_experts,
|
||||
intermediate_size,
|
||||
hidden_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
num_groups_w13 = hidden_size // self.quant_config.group_size
|
||||
num_groups_w2 = intermediate_size // self.quant_config.group_size
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
w13_scales = Parameter(torch.empty(num_experts,
|
||||
num_groups_w13,
|
||||
intermediate_size * 2,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = Parameter(torch.empty(num_experts,
|
||||
num_groups_w2,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_ZERO_POINT
|
||||
# Allocate 2 zero points for w1 and w3 respectively.
|
||||
w13_qzeros = Parameter(torch.empty(num_experts,
|
||||
num_groups_w13,
|
||||
2 * intermediate_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = Parameter(torch.empty(num_experts,
|
||||
num_groups_w2,
|
||||
hidden_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_qweight.shape[0]
|
||||
device = layer.w13_qweight.device
|
||||
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
marlin_w13_qweight = ops.awq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
size_k=layer.w13_qweight.shape[1],
|
||||
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
|
||||
marlin_w2_qweight = ops.awq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
size_k=layer.w2_qweight.shape[1],
|
||||
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
|
||||
# Why does this take the intermediate size for size_k?
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
marlin_w13_zp = moe_awq_to_marlin_zero_points(
|
||||
layer.w13_qzeros,
|
||||
size_k=layer.w13_qzeros.shape[1],
|
||||
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits)
|
||||
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
|
||||
|
||||
marlin_w2_zp = moe_awq_to_marlin_zero_points(
|
||||
layer.w2_qzeros,
|
||||
size_k=layer.w2_qzeros.shape[1],
|
||||
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits)
|
||||
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
fused_marlin_moe)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function)
|
||||
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user