[Bug] Fix Trtllm Fp8 MoE Weight Shuffle Memory Fragmentation (#39054)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
@@ -322,20 +322,23 @@ def _shuffle_deepseek_fp8_moe_weights(
|
||||
block_k = 128
|
||||
num_experts = w13.shape[0]
|
||||
|
||||
w13_shuffled: list[torch.Tensor] = []
|
||||
w2_shuffled: list[torch.Tensor] = []
|
||||
M13, K13 = w13.shape[1], w13.shape[2]
|
||||
M2, K2 = w2.shape[1], w2.shape[2]
|
||||
w13_out = torch.empty(
|
||||
num_experts, K13 // block_k, M13, block_k, dtype=torch.uint8, device=w13.device
|
||||
)
|
||||
w2_out = torch.empty(
|
||||
num_experts, K2 // block_k, M2, block_k, dtype=torch.uint8, device=w2.device
|
||||
)
|
||||
|
||||
for i in range(num_experts):
|
||||
t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
|
||||
t13 = convert_to_block_layout(t13, block_k)
|
||||
w13_shuffled.append(t13)
|
||||
w13_out[i] = convert_to_block_layout(t13, block_k)
|
||||
|
||||
t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
|
||||
t2 = convert_to_block_layout(t2, block_k)
|
||||
w2_shuffled.append(t2)
|
||||
w2_out[i] = convert_to_block_layout(t2, block_k)
|
||||
|
||||
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
|
||||
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
|
||||
return w13_out, w2_out
|
||||
return w13_out.view(torch.float8_e4m3fn), w2_out.view(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def _shuffle_mxfp8_moe_weights(
|
||||
|
||||
Reference in New Issue
Block a user