[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order (#11528)
Signed-off-by: ElizaWszola <eliza@neuralmagic.com> Co-authored-by: ElizaWszola <eliza@neuralmagic.com> Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -386,8 +386,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
@@ -402,30 +402,34 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
# scales, the output_size of the weights for both the gate and up
|
||||
# layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size % block_n != 0:
|
||||
if intermediate_size_per_partition % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
if (tp_size > 1 and intermediate_size % block_k != 0):
|
||||
if (tp_size > 1
|
||||
and intermediate_size_per_partition % block_k != 0):
|
||||
# Required by row parallel
|
||||
raise ValueError(f"The input_size of down's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}.")
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}.")
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
dtype=params_dtype),
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
@@ -446,7 +450,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * ((intermediate_size + block_n - 1) // block_n),
|
||||
2 * ((intermediate_size_per_partition + block_n - 1) //
|
||||
block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
@@ -456,7 +461,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
torch.ones(
|
||||
num_experts,
|
||||
(hidden_size + block_n - 1) // block_n,
|
||||
(intermediate_size + block_k - 1) // block_k,
|
||||
(intermediate_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
|
||||
Reference in New Issue
Block a user