Add thread_n=64 support to Marlin MoE (#32360)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-01-15 19:45:44 -05:00
committed by GitHub
parent c277fbdf31
commit 83239ff19a
3 changed files with 8 additions and 5 deletions

View File

@@ -226,6 +226,7 @@ def prepare_fp8_moe_layer_for_marlin(
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
+ w13_n = w13_weight.size(1)
weight_block_size = getattr(layer, "weight_block_size", None)
# WORKSPACE
@@ -240,7 +241,7 @@ def prepare_fp8_moe_layer_for_marlin(
def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
tensor_list = []
if "w13" in name:
- size_n, size_k = n * 2, k
+ size_n, size_k = w13_n, k
else:
size_n, size_k = k, n
@@ -268,7 +269,7 @@ def prepare_fp8_moe_layer_for_marlin(
scales = scales.to(layer.orig_dtype)
tensor_list = []
if "w13" in name:
- size_n, size_k = n * 2, k
+ size_n, size_k = w13_n, k
else:
size_n, size_k = k, n