[MoE Refactor][10/N] Cleanup Fp8 Process Weights After Loading (#31169)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 from torch.nn import Module
-from torch.nn.parameter import Parameter
 from torch.utils._python_dispatch import TorchDispatchMode
 
 import vllm.envs as envs
@@ -728,6 +727,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant: bool = self.weight_block_size is not None
+        self.weight_scale_name = (
+            "weight_scale_inv" if self.block_quant else "weight_scale"
+        )
         self.fp8_backend = get_fp8_moe_backend(
             self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
         )
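
The `weight_scale_name` attribute only records which checkpoint key the scales are loaded under. A minimal standalone sketch of the naming rule, for illustration only (not code from this diff):

    # Block-quantized FP8 checkpoints store inverse scales under
    # "weight_scale_inv"; per-tensor checkpoints store "weight_scale".
    def weight_scale_name(block_quant: bool) -> str:
        return "weight_scale_inv" if block_quant else "weight_scale"

    assert weight_scale_name(True) == "weight_scale_inv"
    assert weight_scale_name(False) == "weight_scale"
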
@@ -832,38 +834,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # WEIGHT_SCALES
         if not self.block_quant:
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They will be combined to a single scale after weight loading.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # For per-tensor quant, the scales are per expert and weight.
+            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
+            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
         else:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts,
-                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size_per_partition + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
+            # For block quant, the scales are per block (typically 128x128).
+            w13_scale_data = torch.ones(
+                num_experts,
+                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
+                (hidden_size + block_k - 1) // block_k,
+                dtype=torch.float32,
+            )
+            w2_scale_data = torch.ones(
+                num_experts,
+                (hidden_size + block_n - 1) // block_n,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
+                dtype=torch.float32,
+            )
+        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
+        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
+        # Note: name is weight_scale for tensor, weight_scale_inv for block.
+        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
+        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)
 
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
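
For intuition, the block-quant scale shapes above are plain ceil-division over the quant tiles. A minimal standalone sketch with made-up sizes and an assumed 128x128 block; `ceil_div` is a local helper here, not a vLLM API:

    import torch

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    # Illustrative sizes, not taken from any real model config.
    num_experts, intermediate, hidden = 8, 1536, 4096
    block_n = block_k = 128  # assumed 128x128 weight quant blocks

    # w13 stacks w1 and w3, so it has 2 * intermediate output rows; one
    # float32 scale per (block_n x block_k) tile of each expert's weight.
    w13_scales = torch.ones(
        num_experts, 2 * ceil_div(intermediate, block_n), ceil_div(hidden, block_k)
    )
    w2_scales = torch.ones(
        num_experts, ceil_div(hidden, block_n), ceil_div(intermediate, block_k)
    )
    print(w13_scales.shape, w2_scales.shape)  # (8, 24, 32) (8, 32, 12)
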
@@ -877,6 +869,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
+            assert not self.block_quant
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
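
Static activation quantization loads one input scale per expert, which the kernels later collapse to a single value per layer. A minimal sketch of that reduction, with a local re-implementation of the `all_close_1d` check (illustration only, under assumed semantics):

    import torch

    def all_close_1d(x: torch.Tensor) -> bool:
        # True when every expert's scale matches the first one.
        return bool(torch.allclose(x, x[0].expand_as(x)))

    w13_input_scale = torch.tensor([0.02, 0.02, 0.021, 0.02])
    if not all_close_1d(w13_input_scale):
        print("scales differ across experts; using the max")
    scalar_scale = w13_input_scale.max()  # single scale the fused kernel consumes
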
@@ -893,158 +886,60 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
-        # TODO (rob): refactor block quant into separate class.
-        if self.block_quant:
-            assert self.quant_config.activation_scheme == "dynamic"
-            if current_platform.is_fp8_fnuz():
-                w13_weight, w13_weight_scale_inv, w13_input_scale = (
-                    normalize_e4m3fn_to_e4m3fnuz(
-                        layer.w13_weight,
-                        layer.w13_weight_scale_inv,
-                        layer.w13_input_scale,
-                    )
-                )
-                w2_weight, w2_weight_scale_inv, w2_input_scale = (
-                    normalize_e4m3fn_to_e4m3fnuz(
-                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
-                    )
-                )
-            elif self.flashinfer_moe_backend is not None:
-                # NOTE: weights have to be swapped since the activation is
-                # applied on different half for flashinfer vs vllm
-                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
-                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
-                w2_weight = layer.w2_weight.data
-                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
-            else:
-                w13_weight = layer.w13_weight.data
-                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
-                w2_weight = layer.w2_weight
-                w2_weight_scale_inv = layer.w2_weight_scale_inv
-
-            # torch.compile() cannot use Parameter subclasses.
-            replace_parameter(layer, "w13_weight", w13_weight)
-            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
-            replace_parameter(layer, "w2_weight", w2_weight)
-            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
-            if self.fp8_backend == Fp8MoeBackend.AITER:
-                # reshaping weights is required for aiter moe kernel.
-                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data
-                )
-
-                replace_parameter(layer, "w13_weight", shuffled_w13)
-                replace_parameter(layer, "w2_weight", shuffled_w2)
-
-            # DeepGemm scales need to be transposed and aligned. We try to do
-            # it ahead of time for performance reasons.
-            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
-                dg_w13_weight, dg_w13_weight_scale_inv = (
-                    deepgemm_post_process_fp8_weight_block(
-                        wq=layer.w13_weight.data,
-                        ws=layer.w13_weight_scale_inv.data,
-                        quant_block_shape=tuple(layer.weight_block_size),
-                        use_e8m0=is_deep_gemm_e8m0_used(),
-                    )
-                )
-                dg_w2_weight, dg_w2_weight_scale_inv = (
-                    deepgemm_post_process_fp8_weight_block(
-                        wq=layer.w2_weight.data,
-                        ws=layer.w2_weight_scale_inv.data,
-                        quant_block_shape=tuple(layer.weight_block_size),
-                        use_e8m0=is_deep_gemm_e8m0_used(),
-                    )
-                )
-                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
-                layer.w13_weight_scale_inv = Parameter(
-                    dg_w13_weight_scale_inv, requires_grad=False
-                )
-                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
-                layer.w2_weight_scale_inv = Parameter(
-                    dg_w2_weight_scale_inv, requires_grad=False
-                )
-        else:
-            # Fp8 moe kernels require a single activation scale.
-            # We take the max of all the scales in case they differ.
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
-                    layer.w2_input_scale
-                ):
-                    logger.warning_once(
-                        "Found input_scales that are not equal for "
-                        "fp8 MoE layer. Using the maximum across experts "
-                        "for each layer."
-                    )
-                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
-                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
-            if current_platform.is_fp8_fnuz():
-                # Normalize the weights and scales
-                w13_weight, w13_weight_scale, w13_input_scale = (
-                    normalize_e4m3fn_to_e4m3fnuz(
-                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
-                    )
-                )
-                w2_weight, w2_weight_scale, w2_input_scale = (
-                    normalize_e4m3fn_to_e4m3fnuz(
-                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
-                    )
-                )
-                # Reset the parameter
-                replace_parameter(layer, "w13_weight", w13_weight)
-                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
-                if w13_input_scale is not None:
-                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
-                replace_parameter(layer, "w2_weight", w2_weight)
-                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
-                if w2_input_scale is not None:
-                    replace_parameter(layer, "w2_input_scale", w2_input_scale)
-
-            # Fp8 moe kernel needs single weight scale for w13 per expert.
-            # We take the max then dequant and requant each expert.
-            assert layer.w13_weight_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.local_num_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        layer.w13_weight[expert_id][start : start + shard_size, :],
-                        layer.w13_weight_scale[expert_id][shard_id],
-                    )
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                    start += shard_size
-
-            if self.fp8_backend == Fp8MoeBackend.AITER:
-                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                    layer.w13_weight, layer.w2_weight
-                )
-
-                replace_parameter(layer, "w13_weight", shuffled_w13)
-                replace_parameter(layer, "w2_weight", shuffled_w2)
-
-            replace_parameter(layer, "w13_weight_scale", max_w13_scales)
-
-            if self.flashinfer_moe_backend is not None:
-                # NOTE: weights have to be swapped since the activation is
-                # applied on different half for flashinfer vs vllm
-                assert not self.block_quant
-                register_moe_scaling_factors(layer)
-                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
-                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
-                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
-                layer.w13_weight.data = w13_weight.data
+    def _convert_weights_to_kernel_format(
+        self,
+        layer: Module,
+        w13_weight: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w13_weight_scale: torch.Tensor,
+        w2_weight_scale: torch.Tensor,
+    ) -> None:
+        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
+            assert self.block_quant
+            w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
+                wq=w13_weight,
+                ws=w13_weight_scale,
+                quant_block_shape=tuple(layer.weight_block_size),
+                use_e8m0=is_deep_gemm_e8m0_used(),
+            )
+            w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
+                wq=w2_weight,
+                ws=w2_weight_scale,
+                quant_block_shape=tuple(layer.weight_block_size),
+                use_e8m0=is_deep_gemm_e8m0_used(),
+            )
+        elif self.fp8_backend == Fp8MoeBackend.AITER:
+            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
+                w13_weight, w2_weight
+            )
+        elif self.fp8_backend in [
+            Fp8MoeBackend.FLASHINFER_CUTLASS,
+            Fp8MoeBackend.FLASHINFER_TRTLLM,
+        ]:
+            w13_weight = swap_w13_to_w31(w13_weight)
+            if self.block_quant:
+                w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
+            else:
+                # TODO(rob): this function is a hack that renames the scaling
+                # factors in the Module. This is a hack we should clean up.
+                register_moe_scaling_factors(layer)
+            if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+                rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
+        elif self.fp8_backend == Fp8MoeBackend.AITER:
+            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
+                w13_weight, w2_weight
+            )
+
+        # Replace parameters with updated versions. Note that this helper
+        # function ensures the replacement is compatible with RL weight reloads.
+        replace_parameter(layer, "w13_weight", w13_weight)
+        replace_parameter(layer, "w2_weight", w2_weight)
+        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
+        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
+
+        # TODO(rob): we do this after replace_parameter() because
+        # prepare_moe_fp8_layer_for_marlin uses on the layer's params
+        # directly. We will refactor this in a follow up PR.
         if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
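
`swap_w13_to_w31` reorders the fused gate/up halves because, per the old NOTE above, FlashInfer applies the activation to the opposite half from vLLM. A standalone sketch of the presumed half-swap semantics (illustrative re-implementation, not the vLLM helper itself):

    import torch

    def swap_w13_to_w31_sketch(w: torch.Tensor) -> torch.Tensor:
        # w is [num_experts, 2 * n, k]: first half w1 (gate), second half
        # w3 (up). Swap so the consumer kernel sees [w3; w1] ordering.
        n = w.shape[-2] // 2
        return torch.cat([w[..., n:, :], w[..., :n, :]], dim=-2)

    w = torch.arange(2 * 8 * 3).reshape(2, 8, 3).float()
    swapped = swap_w13_to_w31_sketch(w)
    assert torch.equal(swapped[:, :4], w[:, 4:])
    assert torch.equal(swapped[:, 4:], w[:, :4])
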
@@ -1053,6 +948,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         del layer.w13_input_scale
         del layer.w2_input_scale
 
+    def _setup_kernel(self, layer: Module) -> None:
+        """Setup Modular Kernel for TP Case"""
         # NOTE(rob): this is a WIP refactor. We are first migrating
         # all of the kernels in the TP case to use mk. Once this is
         # done, then we will initialzie the TP case and DP/EP case
@@ -1134,6 +1031,71 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         self.use_inplace = True
 
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
+        # Allow for accessing weights and scales in standard way.
+        w13_weight = layer.w13_weight
+        w2_weight = layer.w2_weight
+        w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
+        w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
+
+        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
+        if current_platform.is_fp8_fnuz():
+            w13_weight, w13_weight_scale, layer.w13_input_scale = (
+                normalize_e4m3fn_to_e4m3fnuz(
+                    w13_weight, w13_weight_scale, layer.w13_input_scale
+                )
+            )
+            w2_weight, w2_weight_scale, layer.w2_input_scale = (
+                normalize_e4m3fn_to_e4m3fnuz(
+                    w2_weight, w2_weight_scale, layer.w2_input_scale
+                )
+            )
+
+        # Per tensor kernels require single activation scale. Use the max.
+        if self.quant_config.activation_scheme == "static":
+            assert not self.block_quant
+            assert layer.w13_input_scale is not None
+            assert layer.w2_input_scale is not None
+            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                layer.w2_input_scale
+            ):
+                logger.warning_once(
+                    "Found input_scales that are not equal for "
+                    "fp8 MoE layer. Using the maximum across experts "
+                    "for each layer."
+                )
+            replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
+            replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
+
+        # Per tensor kernels require single weight scale for w13 per expert, but
+        # on disk there is a scale for w1 and w3. Use the max to requantize.
+        if not self.block_quant:
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        w13_weight[expert_id][start : start + shard_size, :],
+                        w13_weight_scale[expert_id][shard_id],
+                    )
+                    w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+            w13_weight_scale = max_w13_scales
+
+        # Shuffle weights into the runtime format.
+        self._convert_weights_to_kernel_format(
+            layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
+        )
+
+        # Setup modular kernel for TP case.
+        self._setup_kernel(layer)
+
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
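
The requantization loop above dequantizes the w1 and w3 shards with their own scales and requantizes both against the per-expert max, so the fused kernel needs only one w13 scale. A self-contained sketch of the same idea on a single expert; `quant_fp8` is a local stand-in, not vLLM's `ops.scaled_fp8_quant`:

    import torch

    def quant_fp8(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Stand-in for scaled FP8 quantization with a provided scale.
        return (w / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

    shard, k = 16, 32
    w1, w3 = torch.randn(shard, k), torch.randn(shard, k)
    s1 = w1.abs().max() / 448.0  # 448 = E4M3 max representable value
    s3 = w3.abs().max() / 448.0

    # One scale per expert: take the max, dequantize each shard, requantize.
    s_max = torch.maximum(s1, s3)
    w13_q = torch.cat(
        [
            quant_fp8(quant_fp8(w1, s1).float() * s1, s_max),
            quant_fp8(quant_fp8(w3, s3).float() * s3, s_max),
        ]
    )
    print(w13_q.shape, s_max)  # (32, 32); both shards now share s_max
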
@@ -1453,32 +1415,26 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
 
         # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
-        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+        w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
+        w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
 
         for expert in range(layer.local_num_experts):
             w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
             )
             w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
             )
         replace_parameter(layer, "w13_weight", w13_weight)
         replace_parameter(layer, "w2_weight", w2_weight)
 
-        # Reshuffle weights for AITER if needed.
-        if self.fp8_backend == Fp8MoeBackend.AITER:
-            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                layer.w13_weight, layer.w2_weight
-            )
-            replace_parameter(layer, "w13_weight", shuffled_w13)
-            replace_parameter(layer, "w2_weight", shuffled_w2)
+        # Shuffle weights into the runtime format.
+        self._convert_weights_to_kernel_format(
+            layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
+        )
 
-        # Rushuffle weights for MARLIN if needed.
-        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
-            prepare_moe_fp8_layer_for_marlin(
-                layer, False, input_dtype=self.marlin_input_dtype
-            )
+        # Setup modular kernel for TP case.
+        self._setup_kernel(layer)
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
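
In the online path above, each expert of an fp16/bf16 checkpoint is quantized at load time with a dynamically computed per-tensor scale. A minimal sketch of that dynamic mode under assumed semantics; `dynamic_fp8_quant` is a local stand-in for `ops.scaled_fp8_quant` without a provided scale:

    import torch

    def dynamic_fp8_quant(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Scale comes from the tensor's own amax; 448 = E4M3 max.
        scale = w.abs().max().clamp(min=1e-12) / 448.0
        return (w / scale).to(torch.float8_e4m3fn), scale

    num_experts = 4
    w13 = torch.randn(num_experts, 64, 32, dtype=torch.bfloat16)
    w13_q = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
    scales = torch.empty(num_experts, dtype=torch.float32)
    for e in range(num_experts):
        # Quantize expert-by-expert, mirroring the per-expert loop above.
        w13_q[e], scales[e] = dynamic_fp8_quant(w13[e].float())
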