[MoE Refactor][12/N] Marlin Fp8 MoE Pure Function (#31499)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Author: Robert Shaw
Date: 2025-12-29 16:27:00 -05:00 (committed by GitHub)
Parent: c2ff33cc8c
Commit: 9152a30d8f

4 changed files with 92 additions and 76 deletions

@@ -964,12 +964,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
(
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
) = prepare_moe_fp8_layer_for_marlin(
layer,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
input_dtype=self.marlin_input_dtype,
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
layer.workspace = workspace
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
if self.use_cutlass:
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
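The compressed-tensors path now takes the repacked tensors back from prepare_moe_fp8_layer_for_marlin and installs them itself via replace_parameter, rather than letting the helper mutate the layer. As a rough, hypothetical sketch (not the vLLM implementation), a replace_parameter-style helper amounts to re-registering the attribute as a fresh non-trainable Parameter:

import torch

def replace_parameter_sketch(
    module: torch.nn.Module, name: str, new: torch.Tensor
) -> None:
    # Re-register under the same name so the attribute path (layer.w13_weight,
    # layer.w13_weight_scale, ...) and the module's state_dict stay in sync
    # with the repacked tensor; requires_grad=False for inference-only weights.
    module.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))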

@@ -912,6 +912,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
(
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
) = prepare_moe_fp8_layer_for_marlin(
layer,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
input_dtype=self.marlin_input_dtype,
)
layer.workspace = workspace
elif self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
@@ -937,17 +954,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
# TODO(rob): we do this after replace_parameter() because
# prepare_moe_fp8_layer_for_marlin operates on the layer's params
# directly. We will refactor this in a follow up PR.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
def _setup_kernel(self, layer: Module) -> None:
"""Setup Modular Kernel for TP Case"""
# NOTE(rob): this is a WIP refactor. We are first migrating
@@ -1194,20 +1200,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) -> FusedMoEQuantConfig | None:
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
block_shape=self.weight_block_size,
)
return fp8_w8a8_moe_quant_config(
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.weight_block_size,
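
The quant-config hunk above collapses the old block-quant ternaries into a single getattr on f"w13_{self.weight_scale_name}". The diff does not show where weight_scale_name is defined; assuming it resolves to "weight_scale_inv" for block-quantized checkpoints and "weight_scale" otherwise (which is what the removed ternary implies), the lookup is equivalent to this sketch:

def moe_weight_scales(layer, block_quant: bool):
    # Hypothetical name resolution implied by the removed ternary: block-quantized
    # checkpoints register their scales under *_weight_scale_inv.
    weight_scale_name = "weight_scale_inv" if block_quant else "weight_scale"
    w13_scale = getattr(layer, f"w13_{weight_scale_name}")
    w2_scale = getattr(layer, f"w2_{weight_scale_name}")
    return w13_scale, w2_scale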

@@ -315,10 +315,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
(workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = (
prepare_moe_fp8_layer_for_marlin(
layer,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
)
)
layer.workspace = workspace
# TODO(rob): once we apply refactor to Quark, switch to using
# replace_parameter for compatibility with reloading in RL.
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
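
The Quark path keeps plain Parameter re-assignment for now; the TODO above notes a later switch to replace_parameter for compatibility with weight reloading in RL. One common reason such helpers exist (not necessarily the exact rationale here) is that re-binding the attribute strands any previously captured reference on the old tensor, as this generic PyTorch snippet shows:

import torch

layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.zeros(4), requires_grad=False)

external_ref = layer.w13_weight          # e.g. a handle held by a weight loader
layer.w13_weight = torch.nn.Parameter(torch.ones(4), requires_grad=False)

print(external_ref is layer.w13_weight)  # False: the old Parameter object is stale
print(external_ref.sum().item())         # 0.0 -- it never sees the new values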

@@ -199,9 +199,18 @@ def prepare_fp8_layer_for_marlin(
def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module,
size_k_first: bool = True,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
input_dtype: torch.dtype | None = None,
) -> None:
) -> tuple[
torch.Tensor, # workspace
torch.Tensor, # w13_weight
torch.Tensor, # w2_weight
torch.Tensor, # w13_weight_scale
torch.Tensor, # w2_weight_scale
]:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
@@ -209,7 +218,7 @@ def prepare_moe_fp8_layer_for_marlin(
"performance for compute-heavy workloads."
)
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.")
raise NotImplementedError("Marlin W8A8 is not supported.")
e = layer.num_experts
k = layer.hidden_size
@@ -218,53 +227,40 @@ def prepare_moe_fp8_layer_for_marlin(
# WORKSPACE
device = layer.w13_weight.device
layer.workspace = marlin_make_workspace_new(device, 4)
workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
# Repack weights to marlin format
for name in ["w13_weight", "w2_weight"]:
weight = getattr(layer, name)
def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
if size_k_first:
assert weight.shape == (e, size_k, size_n)
else:
assert weight.shape == (e, size_n, size_k)
assert weight.shape == (e, size_n, size_k)
for i in range(e):
qweight = pack_fp8_to_int32(weight[i], size_k_first)
if not size_k_first:
qweight = qweight.T.contiguous()
qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
qweight = qweight.T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
)
tensor_list.append(marlin_qweight)
weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
weight = torch.nn.Parameter(weight, requires_grad=False)
return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
setattr(layer, name, weight)
w13_weight = repack_weight("w13", w13_weight)
w2_weight = repack_weight("w2", w2_weight)
# WEIGHT SCALES
# Permute scales
group_size = -1 if weight_block_size is None else weight_block_size[1]
for name in ["w13", "w2"]:
if name + "_weight_scale" in dir(layer):
new_name = name + "_weight_scale"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
elif name + "_weight_scale_inv" in dir(layer):
new_name = name + "_weight_scale_inv"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
scales = scales.to(layer.orig_dtype)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
@@ -294,8 +290,7 @@ def prepare_moe_fp8_layer_for_marlin(
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1)
scales = scales.permute(0, 2, 1)
block_n = weight_block_size[0]
scales = scales.repeat_interleave(block_n, 2)
# size_n may not divisible by block_size[0]
@@ -310,26 +305,18 @@ def prepare_moe_fp8_layer_for_marlin(
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales)
scales = torch.nn.Parameter(scales, requires_grad=False)
return scales
setattr(layer, name + "_weight_scale", scales)
w13_weight_scale = permute_scales(w13_weight_scale, "w13")
w2_weight_scale = permute_scales(w2_weight_scale, "w2")
# BIAS
# Permute bias
for name in ["w13_bias", "w2_bias"]:
if not hasattr(layer, name):
continue
bias = getattr(layer, name).to(layer.orig_dtype)
tensor_list = []
for i in range(e):
expert_bias = bias[i]
tensor_list.append(marlin_permute_bias(expert_bias))
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
bias = torch.nn.Parameter(bias, requires_grad=False)
setattr(layer, name, bias)
return (
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
)
def pack_fp8_to_int32(
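
The repack loop packs each expert's fp8 weight into int32 via pack_fp8_to_int32 before handing it to ops.gptq_marlin_repack with num_bits=8. At the byte level the packing amounts to reinterpreting each group of four consecutive fp8 values as one 32-bit word; a simplified sketch (not the actual implementation, which also handles the size_k_first layout) could look like:

import torch

def pack_fp8_rows_to_int32(w: torch.Tensor) -> torch.Tensor:
    # w: (size_k, size_n) float8 tensor with size_k divisible by 4.
    k, n = w.shape
    as_bytes = w.view(torch.uint8)                   # reinterpret each fp8 byte
    grouped = as_bytes.reshape(k // 4, 4, n)         # 4 consecutive k-values per column
    grouped = grouped.permute(0, 2, 1).contiguous()  # (k // 4, n, 4)
    # Each group of 4 bytes becomes one little-endian int32 word.
    return grouped.view(torch.int32).squeeze(-1)     # (k // 4, size_n)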