diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 13c82893d..0e39dc881 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -322,20 +322,23 @@ def _shuffle_deepseek_fp8_moe_weights(
     block_k = 128
     num_experts = w13.shape[0]

-    w13_shuffled: list[torch.Tensor] = []
-    w2_shuffled: list[torch.Tensor] = []
+    M13, K13 = w13.shape[1], w13.shape[2]
+    M2, K2 = w2.shape[1], w2.shape[2]
+    w13_out = torch.empty(
+        num_experts, K13 // block_k, M13, block_k, dtype=torch.uint8, device=w13.device
+    )
+    w2_out = torch.empty(
+        num_experts, K2 // block_k, M2, block_k, dtype=torch.uint8, device=w2.device
+    )
+
     for i in range(num_experts):
         t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
-        t13 = convert_to_block_layout(t13, block_k)
-        w13_shuffled.append(t13)
+        w13_out[i] = convert_to_block_layout(t13, block_k)

         t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
-        t2 = convert_to_block_layout(t2, block_k)
-        w2_shuffled.append(t2)
+        w2_out[i] = convert_to_block_layout(t2, block_k)

-    w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
-    w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
-    return w13_out, w2_out
+    return w13_out.view(torch.float8_e4m3fn), w2_out.view(torch.float8_e4m3fn)


 def _shuffle_mxfp8_moe_weights(