diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 27e646bcd..eaebf4e35 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -31,8 +31,6 @@ namespace moe { constexpr unsigned FULL_WARP_MASK = 0xffffffff; constexpr int32_t WARP_SIZE = 32; -constexpr int32_t BLOCK_SIZE = 512; -constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; namespace warp_topk { @@ -65,14 +63,6 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, return res; } -template -int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { - int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; - int64_t n = std::max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); - return max(cache_topk, - round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); -} - template struct BitonicMerge { @@ -267,6 +257,15 @@ class WarpSort { } } + // Accessors for per-lane selected value/index. + // NOTE: For the common case `capacity == WARP_SIZE`, `max_arr_len_ == 1` + // and callers should use `i == 0`. + __device__ __forceinline__ idxT get_idx(int i = 0) const { + return idx_arr_[i]; + } + + __device__ __forceinline__ T get_val(int i = 0) const { return val_arr_[i]; } + protected: static constexpr int max_arr_len_ = capacity / WARP_SIZE; @@ -285,6 +284,7 @@ class WarpSelect : public WarpSort { __device__ WarpSelect(idxT k, T dummy) : WarpSort(k, dummy), k_th_(dummy), + k_th_idx_(0), k_th_lane_((k - 1) % WARP_SIZE) { extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; @@ -346,9 +346,6 @@ class WarpSelect : public WarpSort { idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; merge_buf_(val, idx); } - - // after done(), smem is used for merging results among warps - __syncthreads(); } private: @@ -503,255 +500,186 @@ __device__ void topk_with_k2(T* output, T const* input, BiasT const* bias, } } -template -__global__ void topk_with_k2_kernel(T* output, T* input, BiasT const* bias, - int64_t const num_tokens, - int64_t const num_cases, - int64_t const n_group, - int64_t const num_experts_per_group) { - int32_t warp_id = threadIdx.x / WARP_SIZE; - int32_t lane_id = threadIdx.x % WARP_SIZE; - - int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; - if (case_id < num_cases) { - input += case_id * num_experts_per_group; - // bias is per expert group, offset to current group - int32_t group_id = case_id % n_group; - BiasT const* group_bias = bias + group_id * num_experts_per_group; - output += case_id; - - cg::thread_block block = cg::this_thread_block(); - cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - topk_with_k2(output, input, group_bias, tile, lane_id, - num_experts_per_group); - } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} - -template -__global__ void group_idx_and_topk_idx_kernel( - T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, - BiasT const* bias, int64_t const num_tokens, int64_t const n_group, - int64_t const topk_group, int64_t const topk, int64_t const num_experts, - int64_t const num_experts_per_group, bool renormalize, +template +__global__ void grouped_topk_fused_kernel( + T* scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, bool renormalize, double routed_scaling_factor) { - int32_t warp_id = threadIdx.x / WARP_SIZE; - int32_t lane_id = threadIdx.x % WARP_SIZE; - int32_t case_id = - blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token - scores += case_id * num_experts; - group_scores += case_id * n_group; - topk_values += case_id * topk; - topk_indices += case_id * topk; + int32_t const token_id = static_cast(blockIdx.x); + if (token_id >= num_tokens) { + return; + } - constexpr bool kUseStaticNGroup = (NGroup > 0); - // use int32 to avoid implicit conversion - int32_t const n_group_i32 = - kUseStaticNGroup ? NGroup : static_cast(n_group); + int32_t const warp_id = threadIdx.x / WARP_SIZE; + int32_t const lane_id = threadIdx.x % WARP_SIZE; - int32_t align_num_experts_per_group = - warp_topk::round_up_to_multiple_of(num_experts_per_group); + int32_t const n_group_i32 = static_cast(n_group); + int32_t const topk_group_i32 = static_cast(topk_group); + int32_t const topk_i32 = static_cast(topk); + int32_t const num_experts_i32 = static_cast(num_experts); + + int32_t const num_warps = blockDim.x / WARP_SIZE; + if (warp_id >= n_group_i32 || num_warps < n_group_i32) { + return; + } + + int32_t const num_experts_per_group = num_experts_i32 / n_group_i32; + + T* scores_token = scores + static_cast(token_id) * num_experts; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); - extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to - // store the target topk idx - int32_t* s_topk_idx = reinterpret_cast(smem_buf); - T* s_topk_value = - reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + - warp_id * topk; - s_topk_idx += warp_id * topk; + extern __shared__ char smem_buf[]; + // warpSelect internal staging buffer layout + size_t const val_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(T); + size_t const val_bytes_aligned = + warp_topk::round_up_to_multiple_of<256>(val_bytes); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + size_t const internal_bytes = val_bytes_aligned + idx_bytes; - T value = neg_inf(); - T topk_group_value = neg_inf(); - int32_t num_equalto_topkth_group; + // user-managed shared memory starts after warpSelect internal staging. + uintptr_t ptr_u = reinterpret_cast(smem_buf + internal_bytes); + ptr_u = (ptr_u + 15) & ~static_cast(15); // align to 16B + T* s_group_scores = reinterpret_cast(ptr_u); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before // acqbulk because it's ptr arithmetic #endif - if (case_id < num_tokens) { - // calculate group_idx - int32_t target_num_min = - WARP_SIZE - n_group_i32 + static_cast(topk_group); - // The check is necessary to avoid abnormal input - if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) { - value = group_scores[lane_id]; - } + // phase 1: per-group scan + int32_t const group_offset = warp_id * num_experts_per_group; + topk_with_k2(s_group_scores + warp_id, + scores_token + group_offset, bias + group_offset, + tile, lane_id, num_experts_per_group); - int count_equal_to_top_value = WARP_SIZE - n_group_i32; - int pre_count_equal_to_top_value = 0; - // Use loop to find the largset top_group - while (count_equal_to_top_value < target_num_min) { - topk_group_value = cg::reduce(tile, value, cg::greater()); - if (value == topk_group_value) { - value = neg_inf(); - } - pre_count_equal_to_top_value = count_equal_to_top_value; - count_equal_to_top_value = - __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf()))); - } - num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; - } __syncthreads(); + // phase 2: warp0 selects groups + merges candidates to final topk + if (warp_id != 0) { + return; + } + + topk_values += static_cast(token_id) * topk; + topk_indices += static_cast(token_id) * topk; + + // select topk_group groups by group score warp_topk::WarpSelect - queue((int32_t)topk, neg_inf()); + group_sel(static_cast(topk_group_i32), neg_inf()); - int count_equalto_topkth_group = 0; - bool if_proceed_next_topk = topk_group_value != neg_inf(); - if (case_id < num_tokens && if_proceed_next_topk) { - auto process_group = [&](int i_group) { - if ((group_scores[i_group] > topk_group_value) || - ((group_scores[i_group] == topk_group_value) && - (count_equalto_topkth_group < num_equalto_topkth_group))) { - int32_t offset = i_group * num_experts_per_group; - for (int32_t i = lane_id; i < align_num_experts_per_group; - i += WARP_SIZE) { - T candidates = neg_inf(); - if (i < num_experts_per_group) { - // apply scoring function (if any) and add bias - T input = scores[offset + i]; - if (is_finite(input)) { - T score = apply_scoring(input); - candidates = score + static_cast(bias[offset + i]); - } - } - queue.add(candidates, offset + i); - } - if (group_scores[i_group] == topk_group_value) { - count_equalto_topkth_group++; + // all lanes must participate in WarpSelect::add(). + T gscore = (lane_id < n_group_i32) ? s_group_scores[lane_id] : neg_inf(); + group_sel.add(gscore, lane_id); + group_sel.done(); + + // proceed only if the k-th selected group score is not -inf + bool proceed = false; + if (topk_group_i32 > 0) { + int const kth_lane = topk_group_i32 - 1; + // broadcast the k-th selected group score to all lanes + T kth_val = __shfl_sync(FULL_WARP_MASK, group_sel.get_val(0), kth_lane); + proceed = (kth_val != neg_inf()); + } + + if (!proceed) { + for (int i = lane_id; i < topk_i32; i += WARP_SIZE) { + topk_indices[i] = static_cast(i); + topk_values[i] = 1.0f / static_cast(topk_i32); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif + return; + } + + // merge per-group topk candidates for selected groups, then select topk + warp_topk::WarpSelect + expert_sel(static_cast(topk_i32), neg_inf()); + + // selected group ids reside in lanes [0, topk_group) + int32_t sel_gid_lane = (lane_id < topk_group_i32) ? group_sel.get_idx(0) : 0; + + // add candidates from selected groups to expert_sel + for (int32_t g = 0; g < topk_group_i32; ++g) { + int32_t gid = __shfl_sync(FULL_WARP_MASK, sel_gid_lane, g); + int32_t const offset = gid * num_experts_per_group; + int32_t const align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) { + // all lanes must call `add()` the same number of times. + T cand = neg_inf(); + int32_t idx = 0; + if (i < num_experts_per_group) { + idx = offset + i; + T input = scores_token[idx]; + if (is_finite(input)) { + T score = apply_scoring(input); + cand = score + static_cast(bias[idx]); } } - }; - - if constexpr (kUseStaticNGroup) { -#pragma unroll - for (int i_group = 0; i_group < NGroup; ++i_group) { - process_group(i_group); - } - } else { - for (int i_group = 0; i_group < n_group_i32; ++i_group) { - process_group(i_group); - } - } - queue.done(); - // Get the topk_idx - queue.dumpIdx(s_topk_idx); - } - - // Load the valid score value - // Calculate the summation - float topk_sum = 1e-20; - if (case_id < num_tokens && if_proceed_next_topk) { - for (int i = lane_id; - i < warp_topk::round_up_to_multiple_of(topk); - i += WARP_SIZE) { - T value = cuda_cast(0.0f); - if (i < topk) { - // Load the score value (without bias) for normalization - T input = scores[s_topk_idx[i]]; - value = apply_scoring(input); - s_topk_value[i] = value; - } - if (renormalize) { - topk_sum += - cg::reduce(tile, cuda_cast(value), cg::plus()); - } + expert_sel.add(cand, idx); } } + expert_sel.done(); - __syncthreads(); - - if (case_id < num_tokens) { - if (if_proceed_next_topk) { - float scale = routed_scaling_factor; - if (renormalize) { - scale /= topk_sum; - } - for (int i = lane_id; i < topk; i += WARP_SIZE) { - float base = cuda_cast(s_topk_value[i]); - float value = base * scale; - topk_indices[i] = s_topk_idx[i]; - topk_values[i] = value; - } - } else { - for (int i = lane_id; i < topk; i += WARP_SIZE) { - topk_indices[i] = i; - topk_values[i] = 1.0f / topk; - } - } - // Note: when if_proceed_next_topk==false, choose the first 8 experts as the - // default result. + // compute unbiased routing weights + optional renorm. + float lane_unbiased = 0.0f; + IdxT lane_idx = 0; + if (lane_id < topk_i32) { + lane_idx = static_cast(expert_sel.get_idx(0)); + T in = scores_token[static_cast(lane_idx)]; + lane_unbiased = cuda_cast(apply_scoring(in)); } + + float topk_sum = 1e-20f; + if (renormalize) { + topk_sum += cg::reduce(tile, lane_unbiased, cg::plus()); + } + + float scale = static_cast(routed_scaling_factor); + if (renormalize) { + scale /= topk_sum; + } + + if (lane_id < topk_i32) { + topk_indices[lane_id] = lane_idx; + topk_values[lane_id] = lane_unbiased * scale; + } + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -template -inline void launch_group_idx_and_topk_kernel( - cudaLaunchConfig_t const& config, T* scores, T* group_scores, - float* topk_values, IdxT* topk_indices, BiasT const* bias, - int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, - int64_t const topk, int64_t const num_experts, - int64_t const num_experts_per_group, bool const renormalize, - double const routed_scaling_factor) { - auto launch = [&](auto* kernel_instance2) { - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts_per_group, - renormalize, routed_scaling_factor); - }; - - switch (n_group) { - case 4: { - launch(&group_idx_and_topk_idx_kernel); - break; - } - case 8: { - launch(&group_idx_and_topk_idx_kernel); - break; - } - case 16: { - launch(&group_idx_and_topk_idx_kernel); - break; - } - case 32: { - launch(&group_idx_and_topk_idx_kernel); - break; - } - default: { - launch(&group_idx_and_topk_idx_kernel); - break; - } - } -} - template -void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, - IdxT* topk_indices, BiasT const* bias, - int64_t const num_tokens, int64_t const num_experts, - int64_t const n_group, int64_t const topk_group, - int64_t const topk, bool const renormalize, - double const routed_scaling_factor, int const scoring_func, - bool enable_pdl = false, cudaStream_t const stream = 0) { - int64_t num_cases = num_tokens * n_group; - int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; +void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices, + BiasT const* bias, int64_t const num_tokens, + int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, + bool const renormalize, double const routed_scaling_factor, + int const scoring_func, bool enable_pdl = false, + cudaStream_t const stream = 0) { cudaLaunchConfig_t config; - config.gridDim = topk_with_k2_num_blocks; - config.blockDim = BLOCK_SIZE; - config.dynamicSmemBytes = 0; + // One block per token; one warp per group. + config.gridDim = static_cast(num_tokens); + config.blockDim = static_cast(n_group) * WARP_SIZE; + // Dynamic shared memory: WarpSelect staging + per-group topk buffers. + int32_t const num_warps = static_cast(n_group); + size_t const val_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(T); + size_t const val_bytes_aligned = + warp_topk::round_up_to_multiple_of<256>(val_bytes); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + size_t const internal_bytes = val_bytes_aligned + idx_bytes; + size_t const extra_bytes = 16 + static_cast(n_group) * sizeof(T); + config.dynamicSmemBytes = internal_bytes + extra_bytes; config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; @@ -759,66 +687,35 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, config.numAttrs = 1; config.attrs = attrs; auto const sf = static_cast(scoring_func); - int64_t const num_experts_per_group = num_experts / n_group; - auto launch_topk_with_k2 = [&](auto* kernel_instance1) { - cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - num_tokens, num_cases, n_group, num_experts_per_group); - }; switch (sf) { case SCORING_NONE: { - auto* kernel_instance1 = &topk_with_k2_kernel; - launch_topk_with_k2(kernel_instance1); - break; + auto* kernel_instance = + &grouped_topk_fused_kernel; + cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, + topk_indices, bias, num_tokens, num_experts, n_group, + topk_group, topk, renormalize, routed_scaling_factor); + return; } case SCORING_SIGMOID: { - auto* kernel_instance1 = &topk_with_k2_kernel; - launch_topk_with_k2(kernel_instance1); - break; + auto* kernel_instance = + &grouped_topk_fused_kernel; + cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, + topk_indices, bias, num_tokens, num_experts, n_group, + topk_group, topk, renormalize, routed_scaling_factor); + return; } default: // should be guarded by higher level checks. TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); } - - int64_t topk_with_k_group_num_blocks = - (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; - size_t dynamic_smem_in_bytes = - warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, - topk); - config.gridDim = topk_with_k_group_num_blocks; - config.blockDim = BLOCK_SIZE; - config.dynamicSmemBytes = dynamic_smem_in_bytes; - config.stream = stream; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; - config.numAttrs = 1; - config.attrs = attrs; - switch (sf) { - case SCORING_NONE: { - launch_group_idx_and_topk_kernel( - config, scores, group_scores, topk_values, topk_indices, bias, - num_tokens, n_group, topk_group, topk, num_experts, - num_experts_per_group, renormalize, routed_scaling_factor); - break; - } - case SCORING_SIGMOID: { - launch_group_idx_and_topk_kernel( - config, scores, group_scores, topk_values, topk_indices, bias, - num_tokens, n_group, topk_group, topk, num_experts, - num_experts_per_group, renormalize, routed_scaling_factor); - break; - } - default: - TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); - } } -#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \ - template void invokeNoAuxTc( \ - T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \ - BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, \ - int64_t const n_group, int64_t const topk_group, int64_t const topk, \ - bool const renormalize, double const routed_scaling_factor, \ +#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \ + template void invokeNoAuxTc( \ + T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \ + int64_t const num_tokens, int64_t const num_experts, \ + int64_t const n_group, int64_t const topk_group, int64_t const topk, \ + bool const renormalize, double const routed_scaling_factor, \ int const scoring_func, bool enable_pdl, cudaStream_t const stream); INSTANTIATE_NOAUX_TC(float, float, int32_t); @@ -843,17 +740,21 @@ std::tuple grouped_topk( int64_t num_tokens = input_size[0]; int64_t num_experts = input_size[1]; TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor"); + TORCH_CHECK(n_group > 0, "n_group must be positive"); + TORCH_CHECK(topk > 0, "topk must be positive"); + TORCH_CHECK(topk_group > 0, "topk_group must be positive"); + TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group"); TORCH_CHECK(num_experts % n_group == 0, "num_experts should be divisible by n_group"); TORCH_CHECK(n_group <= 32, "n_group should be smaller than or equal to 32 for now"); TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= topk_group * (num_experts / n_group), + "topk must be <= topk_group * (num_experts / n_group)"); TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE || scoring_func == vllm::moe::SCORING_SIGMOID, "scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)"); - torch::Tensor group_scores = torch::empty( - {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); // Always output float32 for topk_values (eliminates Python-side conversion) torch::Tensor topk_values = torch::empty( {num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); @@ -868,7 +769,6 @@ std::tuple grouped_topk( case torch::kFloat16: \ vllm::moe::invokeNoAuxTc( \ reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(group_scores.mutable_data_ptr()), \ reinterpret_cast(topk_values.mutable_data_ptr()), \ reinterpret_cast(topk_indices.mutable_data_ptr()), \ reinterpret_cast(bias.data_ptr()), num_tokens, \ @@ -879,7 +779,6 @@ std::tuple grouped_topk( case torch::kFloat32: \ vllm::moe::invokeNoAuxTc( \ reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(group_scores.mutable_data_ptr()), \ reinterpret_cast(topk_values.mutable_data_ptr()), \ reinterpret_cast(topk_indices.mutable_data_ptr()), \ reinterpret_cast(bias.data_ptr()), num_tokens, \ @@ -890,7 +789,6 @@ std::tuple grouped_topk( case torch::kBFloat16: \ vllm::moe::invokeNoAuxTc( \ reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(group_scores.mutable_data_ptr()), \ reinterpret_cast(topk_values.mutable_data_ptr()), \ reinterpret_cast(topk_indices.mutable_data_ptr()), \ reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \ diff --git a/tests/models/utils.py b/tests/models/utils.py index 1b820d284..4830f18dc 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -454,6 +454,9 @@ def dummy_hf_overrides( # Ensure at least 2 expert per group # Since `grouped_topk` assumes top-2 n_group = getattr(text_config, "n_group", None) + # Kimi uses `num_expert_group` instead of `n_group`. + if n_group is None: + n_group = getattr(text_config, "num_expert_group", None) num_experts = n_group * 2 if n_group is not None else 2 # we use three layers for Gemma-3n to check @@ -487,6 +490,8 @@ def dummy_hf_overrides( { "num_experts": num_experts, "num_experts_per_tok": 2, + # Kimi uses `num_experts_per_token`. + "num_experts_per_token": 2, "num_local_experts": num_experts, # Otherwise there will not be any expert layers "first_k_dense_replace": 0,