[Bugfix][CPU] Fix llama4 inference on CPU (#34321)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
     const int32_t token_num, const int32_t expert_num,
     const int32_t topk_num, const int32_t input_size_13,
     const int32_t output_size_13, const int32_t input_size_2,
-    const int32_t output_size_2) {
+    const int32_t output_size_2, const bool skip_weighted) {
   using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
   constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
   constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
 scalar_t* __restrict__ curr_output_buffer =
     output + token_id * output_size_2;
 
+if (skip_weighted) {
+  // Only for topk_num == 1
+  *curr_weight = 1.0f;
+}
+
 if (topk_num > 1) {
   {
     int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
@@ -699,7 +704,7 @@ void cpu_fused_moe(
     const std::optional<torch::Tensor>& w2_bias,  // [expert_num, output_size_2]
     const torch::Tensor& topk_weights,            // [token_num, k], float32
     const torch::Tensor& topk_id,                 // [token_num, k], int32
-    const std::string& act, const std::string& isa) {
+    const bool skip_weighted, const std::string& act, const std::string& isa) {
   const int32_t token_num = input.size(0);
   const int32_t input_size_13 = input.size(1);
   const int64_t input_stride = input.stride(0);
@@ -711,6 +716,8 @@ void cpu_fused_moe(
   const int32_t topk_num = topk_id.size(1);
   const FusedMOEAct act_type = get_act_type(act);
   cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
+  TORCH_CHECK(!skip_weighted || topk_num == 1,
+              "skip_weighted is only supported for topk=1 on CPU");
 
   VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
     CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
@@ -721,7 +728,7 @@ void cpu_fused_moe(
           w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
           topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
           token_num, expert_num, topk_num, input_size_13, output_size_13,
-          input_size_2, output_size_2);
+          input_size_2, output_size_2, skip_weighted);
     });
   });
 }
Reference in New Issue
Block a user