[Bug] Fix Trtllm Fp8 MoE Weight Shuffle Memory Fragmentation (#39054)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
Wei Zhao
2026-04-07 08:04:08 -04:00
committed by GitHub
parent 7b9de7c892
commit 0be9516ea4

View File

@@ -322,20 +322,23 @@ def _shuffle_deepseek_fp8_moe_weights(
block_k = 128
num_experts = w13.shape[0]
w13_shuffled: list[torch.Tensor] = []
w2_shuffled: list[torch.Tensor] = []
M13, K13 = w13.shape[1], w13.shape[2]
M2, K2 = w2.shape[1], w2.shape[2]
w13_out = torch.empty(
num_experts, K13 // block_k, M13, block_k, dtype=torch.uint8, device=w13.device
)
w2_out = torch.empty(
num_experts, K2 // block_k, M2, block_k, dtype=torch.uint8, device=w2.device
)
for i in range(num_experts):
t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
t13 = convert_to_block_layout(t13, block_k)
w13_shuffled.append(t13)
w13_out[i] = convert_to_block_layout(t13, block_k)
t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
t2 = convert_to_block_layout(t2, block_k)
w2_shuffled.append(t2)
w2_out[i] = convert_to_block_layout(t2, block_k)
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
return w13_out, w2_out
return w13_out.view(torch.float8_e4m3fn), w2_out.view(torch.float8_e4m3fn)
def _shuffle_mxfp8_moe_weights(