[MoE Refactor][12/N] Marlin Fp8 MoE Pure Function (#31499)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
@@ -964,12 +964,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

         elif self.use_marlin:
-            prepare_moe_fp8_layer_for_marlin(
-                layer, False, input_dtype=self.marlin_input_dtype
+            (
+                workspace,
+                w13_weight,
+                w2_weight,
+                w13_weight_scale,
+                w2_weight_scale,
+            ) = prepare_moe_fp8_layer_for_marlin(
+                layer,
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                input_dtype=self.marlin_input_dtype,
             )
             # Activations not quantized for marlin.
             del layer.w13_input_scale
             del layer.w2_input_scale
+            layer.workspace = workspace
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w2_weight", w2_weight)
+            replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
+            replace_parameter(layer, "w2_weight_scale", w2_weight_scale)

         if self.use_cutlass:
             assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
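Note on the pattern above: the helper no longer mutates `layer` in place; it returns the Marlin-formatted tensors and the caller installs them explicitly (here via `replace_parameter`). A minimal standalone sketch of that shape, using made-up tensors and a stand-in transform rather than the real repacking code:

import torch

def prepare_pure(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real weight transformation (repack/permute): no side effects,
    # just input tensor in, transformed tensor out.
    return w.t().contiguous()

layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.randn(8, 16), requires_grad=False)

# The caller decides how the result is installed, e.g. by re-registering the parameter.
new_w13 = prepare_pure(layer.w13_weight.data)
layer.register_parameter("w13_weight", torch.nn.Parameter(new_w13, requires_grad=False))
print(layer.w13_weight.shape)  # torch.Size([16, 8])

Returning tensors instead of assigning them inside the helper is what lets callers choose between `replace_parameter` (kept compatible with weight reloading, per the TODO in the Quark hunk below) and plain `torch.nn.Parameter` reassignment.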
@@ -912,6 +912,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                 w13_weight, w2_weight
             )
+        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
+            (
+                workspace,
+                w13_weight,
+                w2_weight,
+                w13_weight_scale,
+                w2_weight_scale,
+            ) = prepare_moe_fp8_layer_for_marlin(
+                layer,
+                w13_weight,
+                w2_weight,
+                w13_weight_scale,
+                w2_weight_scale,
+                input_dtype=self.marlin_input_dtype,
+            )
+            layer.workspace = workspace
+
         elif self.fp8_backend in [
             Fp8MoeBackend.FLASHINFER_CUTLASS,
             Fp8MoeBackend.FLASHINFER_TRTLLM,
@@ -937,17 +954,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
         replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)

-        # TODO(rob): we do this after replace_parameter() because
-        # prepare_moe_fp8_layer_for_marlin uses on the layer's params
-        # directly. We will refactor this in a follow up PR.
-        if self.fp8_backend == Fp8MoeBackend.MARLIN:
-            prepare_moe_fp8_layer_for_marlin(
-                layer, False, input_dtype=self.marlin_input_dtype
-            )
-            # Activations not quantized for marlin.
-            del layer.w13_input_scale
-            del layer.w2_input_scale
-
     def _setup_kernel(self, layer: Module) -> None:
         """Setup Modular Kernel for TP Case"""
         # NOTE(rob): this is a WIP refactor. We are first migrating
@@ -1194,20 +1200,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     ) -> FusedMoEQuantConfig | None:
         if self.fp8_backend == Fp8MoeBackend.MARLIN:
             return fp8_w8a16_moe_quant_config(
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
+                w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
+                w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
                 block_shape=self.weight_block_size,
             )

         return fp8_w8a8_moe_quant_config(
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
+            w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
+            w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.weight_block_size,
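Both config branches now resolve the scale through one attribute-name pattern instead of branching on `self.block_quant`. The exact definition of `weight_scale_name` is not shown in this diff; the sketch below assumes it selects between the two suffixes seen in the removed code:

import torch

layer = torch.nn.Module()
block_quant = True
# Assumed convention: block-quantized checkpoints register "*_weight_scale_inv",
# per-tensor/channel checkpoints register "*_weight_scale".
weight_scale_name = "weight_scale_inv" if block_quant else "weight_scale"

setattr(
    layer,
    f"w13_{weight_scale_name}",
    torch.nn.Parameter(torch.ones(4, 2), requires_grad=False),
)

# One lookup covers both quantization layouts.
w1_scale = getattr(layer, f"w13_{weight_scale_name}")
print(w1_scale.shape)  # torch.Size([4, 2])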
@@ -315,10 +315,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

         elif self.use_marlin:
-            prepare_moe_fp8_layer_for_marlin(layer, False)
-            # Activations not quantized for marlin.
-            del layer.w13_input_scale
-            del layer.w2_input_scale
+            (workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = (
+                prepare_moe_fp8_layer_for_marlin(
+                    layer,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    layer.w13_weight_scale,
+                    layer.w2_weight_scale,
+                )
+            )
+            layer.workspace = workspace
+            # TODO(rob): once we apply refactor to Quark, switch to using
+            # replace_parameter for compatibility with reloading in RL.
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(
+                w13_weight_scale, requires_grad=False
+            )
+            layer.w2_weight_scale = torch.nn.Parameter(
+                w2_weight_scale, requires_grad=False
+            )

     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
@@ -199,9 +199,18 @@ def prepare_fp8_layer_for_marlin(

 def prepare_moe_fp8_layer_for_marlin(
     layer: torch.nn.Module,
-    size_k_first: bool = True,
+    w13_weight: torch.Tensor,
+    w2_weight: torch.Tensor,
+    w13_weight_scale: torch.Tensor,
+    w2_weight_scale: torch.Tensor,
     input_dtype: torch.dtype | None = None,
-) -> None:
+) -> tuple[
+    torch.Tensor,  # workspace
+    torch.Tensor,  # w13_weight
+    torch.Tensor,  # w2_weight
+    torch.Tensor,  # w13_weight_scale
+    torch.Tensor,  # w2_weight_scale
+]:
     logger.warning_once(
         "Your GPU does not have native support for FP8 computation but "
         "FP8 quantization is being used. Weight-only FP8 compression will "
@@ -209,7 +218,7 @@ def prepare_moe_fp8_layer_for_marlin(
         "performance for compute-heavy workloads."
     )
     if input_dtype is not None and input_dtype.itemsize == 1:
-        raise RuntimeError("Marlin W8A8 is not supported.")
+        raise NotImplementedError("Marlin W8A8 is not supported.")

     e = layer.num_experts
     k = layer.hidden_size
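The guard above rejects any 1-byte activation dtype: this path is weight-only FP8 (W8A16), so FP8 or INT8 activations would mean W8A8, which Marlin does not support here. A quick standalone illustration of the `itemsize` check (not part of the commit):

import torch

for dtype in (torch.float8_e4m3fn, torch.int8, torch.float16, torch.bfloat16):
    verdict = "rejected (would raise)" if dtype.itemsize == 1 else "accepted"
    print(f"{dtype}: itemsize={dtype.itemsize} -> {verdict}")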
@@ -218,53 +227,40 @@ def prepare_moe_fp8_layer_for_marlin(

     # WORKSPACE
     device = layer.w13_weight.device
-    layer.workspace = marlin_make_workspace_new(device, 4)
+    workspace = marlin_make_workspace_new(device, 4)
     perm = torch.empty(0, dtype=torch.int, device=device)

     # WEIGHT
     # Repack weights to marlin format
-    for name in ["w13_weight", "w2_weight"]:
-        weight = getattr(layer, name)
+    def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
         tensor_list = []
         if "w13" in name:
             size_n, size_k = n * 2, k
         else:
             size_n, size_k = k, n

-        if size_k_first:
-            assert weight.shape == (e, size_k, size_n)
-        else:
-            assert weight.shape == (e, size_n, size_k)
+        assert weight.shape == (e, size_n, size_k)

         for i in range(e):
-            qweight = pack_fp8_to_int32(weight[i], size_k_first)
-            if not size_k_first:
-                qweight = qweight.T.contiguous()
+            qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
+            qweight = qweight.T.contiguous()

             marlin_qweight = ops.gptq_marlin_repack(
                 b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
             )
             tensor_list.append(marlin_qweight)

-        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
-        weight = torch.nn.Parameter(weight, requires_grad=False)
+        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

-        setattr(layer, name, weight)
+    w13_weight = repack_weight("w13", w13_weight)
+    w2_weight = repack_weight("w2", w2_weight)

     # WEIGHT SCALES
     # Permute scales
     group_size = -1 if weight_block_size is None else weight_block_size[1]

-    for name in ["w13", "w2"]:
-        if name + "_weight_scale" in dir(layer):
-            new_name = name + "_weight_scale"
-            scales = getattr(layer, new_name).to(layer.orig_dtype)
-            delattr(layer, new_name)
-        elif name + "_weight_scale_inv" in dir(layer):
-            new_name = name + "_weight_scale_inv"
-            scales = getattr(layer, new_name).to(layer.orig_dtype)
-            delattr(layer, new_name)
-
+    def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
+        scales = scales.to(layer.orig_dtype)
         tensor_list = []
         if "w13" in name:
             size_n, size_k = n * 2, k
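For reference, the per-expert stacking idiom used by `repack_weight` and `permute_scales` above, `torch.cat([x.unsqueeze(0) for x in tensor_list], 0)`, is equivalent to `torch.stack(tensor_list, 0)`:

import torch

tensor_list = [torch.randn(16, 32) for _ in range(4)]  # one tensor per expert
stacked_cat = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
stacked = torch.stack(tensor_list, 0)
assert stacked_cat.shape == (4, 16, 32)
assert torch.equal(stacked_cat, stacked)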
@@ -294,8 +290,7 @@ def prepare_moe_fp8_layer_for_marlin(
             # block-wise quantization -> group-wise quantization
             # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
             # =>(repeat)=> (e, size_k // block_size[1], size_n)
-            if not size_k_first:
-                scales = scales.permute(0, 2, 1)
+            scales = scales.permute(0, 2, 1)
             block_n = weight_block_size[0]
             scales = scales.repeat_interleave(block_n, 2)
             # size_n may not divisible by block_size[0]
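The comment above describes turning one scale per quantization block into one scale per output column by repeating along the N dimension. A shape-only sketch of that step with illustrative sizes (the trim for a non-divisible `size_n` is shown here as an assumption about how the remainder is handled):

import torch

e, size_k, size_n = 2, 8, 6
block_k, block_n = 4, 4
num_n_blocks = (size_n + block_n - 1) // block_n  # ceil(size_n / block_n) = 2

scales = torch.rand(e, size_k // block_k, num_n_blocks)  # (2, 2, 2), one scale per block
scales = scales.repeat_interleave(block_n, dim=2)        # (2, 2, 8), one scale per column
scales = scales[..., :size_n]                            # assumed trim when size_n % block_n != 0
print(scales.shape)                                      # torch.Size([2, 2, 6])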
@@ -310,26 +305,18 @@ def prepare_moe_fp8_layer_for_marlin(
         scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
         if input_dtype != torch.float8_e4m3fn:
             scales = fp8_fused_exponent_bias_into_scales(scales)
-        scales = torch.nn.Parameter(scales, requires_grad=False)
+        return scales

-        setattr(layer, name + "_weight_scale", scales)
+    w13_weight_scale = permute_scales(w13_weight_scale, "w13")
+    w2_weight_scale = permute_scales(w2_weight_scale, "w2")

-    # BIAS
-    # Permute bias
-    for name in ["w13_bias", "w2_bias"]:
-        if not hasattr(layer, name):
-            continue
-        bias = getattr(layer, name).to(layer.orig_dtype)
-
-        tensor_list = []
-        for i in range(e):
-            expert_bias = bias[i]
-
-            tensor_list.append(marlin_permute_bias(expert_bias))
-
-        bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
-        bias = torch.nn.Parameter(bias, requires_grad=False)
-        setattr(layer, name, bias)
+    return (
+        workspace,
+        w13_weight,
+        w2_weight,
+        w13_weight_scale,
+        w2_weight_scale,
+    )


 def pack_fp8_to_int32(