Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,30 +10,48 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_dynamic_override, get_linear_quant_method, override_config)
get_dynamic_override,
get_linear_quant_method,
override_config,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_moe_marlin_supports_layer,
marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
check_marlin_supported,
check_moe_marlin_supports_layer,
marlin_make_workspace_new,
marlin_moe_permute_scales,
marlin_permute_bias,
marlin_repeat_scales_on_all_ranks,
verify_marlin_supported,
)
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
@@ -52,9 +70,13 @@ def get_moe_quant_method(
if isinstance(layer, FusedMoE):
# False = skip module, None = no override, else = Positive match
if get_dynamic_override( # noqa: E712
if (
get_dynamic_override( # noqa: E712
cloned_config, # noqa: E712
layer_name=prefix) == False: # noqa: E712
layer_name=prefix,
)
== False
): # noqa: E712
return UnquantizedFusedMoEMethod(layer.moe_config)
if prefix:
@@ -75,15 +97,16 @@ class GPTQMarlinConfig(QuantizationConfig):
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any],
modules_in_block_to_quantize: Optional[list[str]] = None) -> None:
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any],
modules_in_block_to_quantize: Optional[list[str]] = None,
) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
@@ -125,8 +148,9 @@ class GPTQMarlinConfig(QuantizationConfig):
self.full_config = full_config
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
raise ValueError(
f"Unsupported quantization config: bits={weight_bits}, sym={is_sym}"
)
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
@@ -169,50 +193,64 @@ class GPTQMarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized, dynamic, config,
modules_in_block_to_quantize)
config, ["modules_in_block_to_quantize"], default=None
)
return cls(
weight_bits,
group_size,
desc_act,
is_sym,
lm_head_quantized,
dynamic,
config,
modules_in_block_to_quantize,
)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "gptq_marlin")
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
)
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
msg = (
"The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name())
)
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference")
logger.info(
"Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference"
)
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
return get_moe_quant_method(self, layer, prefix,
GPTQMarlinMoEMethod)
return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod)
"Falling back to Moe WNA16 kernels."
)
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix
)
return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod)
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
@@ -229,41 +267,40 @@ class GPTQMarlinConfig(QuantizationConfig):
return False
# Marlin conversion is only valid if required properties are found
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
if num_bits is None or group_size is None or sym is None or desc_act is None:
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size)
return check_marlin_supported(
quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
)
def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)
self.modules_in_block_to_quantize
)
def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
def maybe_update_config(self, model_name: str, revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
item
for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
metadata = get_safetensors_params_metadata(model_name, revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
if (dtype := info.get("dtype", None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
@@ -282,8 +319,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
self.quant_config = quant_config
# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size)
verify_marlin_supported(
quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size,
)
def create_weights(
self,
@@ -301,20 +340,21 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
partition_weight_shape=(
input_size_per_partition,
output_size_per_partition,
),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
has_g_idx=self.quant_config.desc_act,
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQMarlinLinearMethod",
kernel_type.__name__)
logger.info("Using %s for GPTQMarlinLinearMethod", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size
@@ -324,9 +364,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
group_size = input_size
# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
is_row_parallel):
if marlin_repeat_scales_on_all_ranks(
self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
@@ -348,67 +388,69 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
weight_loader=weight_loader,
)
# Activation order
g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader,
)
qzeros_args = {
"data":
torch.empty(
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
"weight_loader": weight_loader,
}
weight_scale_args = {
"data":
torch.empty(
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
"weight_loader": weight_loader,
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
**qzeros_args,
)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
**qzeros_args,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")
self.kernel = kernel_type(
mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx",
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
@@ -437,8 +479,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
else:
raise ValueError(
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
def create_weights(
self,
@@ -449,28 +490,27 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
self.is_k_full = (not self.quant_config.desc_act) or (
intermediate_size_per_partition == intermediate_size_full)
intermediate_size_per_partition == intermediate_size_full
)
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
w2_scales_size = (intermediate_size_full
if self.quant_config.desc_act else
intermediate_size_per_partition)
scales_size2 = (w2_scales_size // self.quant_config.group_size)
w2_scales_size = (
intermediate_size_full
if self.quant_config.desc_act
else intermediate_size_per_partition
)
scales_size2 = w2_scales_size // self.quant_config.group_size
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": True
})
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
torch.empty(
@@ -487,8 +527,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition //
self.quant_config.pack_factor,
intermediate_size_per_partition // self.quant_config.pack_factor,
hidden_size,
dtype=torch.int32,
),
@@ -498,51 +537,51 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qweight, extra_weight_attrs)
# up_proj scales
w13_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size_per_partition,
dtype=params_dtype),
torch.empty(
num_experts,
scales_size13,
2 * intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
# down_proj scales
w2_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=params_dtype),
torch.empty(num_experts, scales_size2, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# don't shard the w2 scales when running act order
set_weight_attrs(w2_scales,
{"load_full_w2": self.quant_config.desc_act})
set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
# up_proj scales
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=params_dtype),
torch.empty(
num_experts,
scales_size13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
# down_proj scales
w2_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype),
torch.empty(
num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
# don't shard the w2 scales when running act order
set_weight_attrs(w2_qzeros,
{"load_full_w2": self.quant_config.desc_act})
set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
@@ -571,8 +610,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
@@ -582,15 +620,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
@@ -600,42 +636,36 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(
layer.w13_g_idx[e]).to(torch.int32)
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
torch.int32
)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]]
torch.int32
)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
# Repack weights
@@ -665,9 +695,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] *
(self.quant_config.group_size if self.quant_config.group_size != -1
else self.quant_config.pack_factor),
size_k=layer.w2_scales.shape[1]
* (
self.quant_config.group_size
if self.quant_config.group_size != -1
else self.quant_config.pack_factor
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
@@ -680,7 +713,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
@@ -710,7 +744,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
"EPLB not supported for `GPTQMarlinMoEMethod` yet."
)
assert activation == "silu", "Only SiLU activation is supported."
@@ -726,7 +761,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
indices_type=self.topk_indices_dtype,
)
return torch.ops.vllm.fused_marlin_moe(
x,
@@ -748,4 +784,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full)
is_k_full=self.is_k_full,
)