diff --git a/.gitignore b/.gitignore
index 375b1b7eb..8e864d090 100644
--- a/.gitignore
+++ b/.gitignore
@@ -238,3 +238,6 @@ ep_kernels_workspace/
 vllm/grpc/vllm_engine_pb2.py
 vllm/grpc/vllm_engine_pb2_grpc.py
 vllm/grpc/vllm_engine_pb2.pyi
+
+# Ignore generated cpu headers
+csrc/cpu/cpu_attn_dispatch_generated.h
diff --git a/csrc/cpu/cpu_fused_moe.cpp b/csrc/cpu/cpu_fused_moe.cpp
index 090e2d4cd..1a8264539 100644
--- a/csrc/cpu/cpu_fused_moe.cpp
+++ b/csrc/cpu/cpu_fused_moe.cpp
@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
                     const int32_t token_num, const int32_t expert_num,
                     const int32_t topk_num, const int32_t input_size_13,
                     const int32_t output_size_13, const int32_t input_size_2,
-                    const int32_t output_size_2) {
+                    const int32_t output_size_2, const bool skip_weighted) {
   using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
   constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
   constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
@@ -582,6 +582,11 @@
       scalar_t* __restrict__ curr_output_buffer =
           output + token_id * output_size_2;
 
+      if (skip_weighted) {
+        // Only for topk_num == 1
+        *curr_weight = 1.0f;
+      }
+
       if (topk_num > 1) {
         {
           int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
@@ -699,7 +704,7 @@ void cpu_fused_moe(
     const std::optional<torch::Tensor>& w2_bias,  // [expert_num, output_size_2]
     const torch::Tensor& topk_weights,            // [token_num, k], float32
     const torch::Tensor& topk_id,                 // [token_num, k], int32
-    const std::string& act, const std::string& isa) {
+    const bool skip_weighted, const std::string& act, const std::string& isa) {
   const int32_t token_num = input.size(0);
   const int32_t input_size_13 = input.size(1);
   const int64_t input_stride = input.stride(0);
@@ -711,6 +716,8 @@ void cpu_fused_moe(
   const int32_t topk_num = topk_id.size(1);
   const FusedMOEAct act_type = get_act_type(act);
   cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
+  TORCH_CHECK(!skip_weighted || topk_num == 1,
+              "skip_weighted is only supported for topk=1 on CPU");
 
   VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
     CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
@@ -721,7 +728,7 @@ void cpu_fused_moe(
           w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
           topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
           token_num, expert_num, topk_num, input_size_13, output_size_13,
-          input_size_2, output_size_2);
+          input_size_2, output_size_2, skip_weighted);
     });
   });
 }
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index b54447b7d..11e1305c6 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
                    const std::optional<torch::Tensor>& w13_bias,
                    const std::optional<torch::Tensor>& w2_bias,
                    const torch::Tensor& topk_weights,
-                   const torch::Tensor& topk_id, const std::string& act,
-                   const std::string& isa);
+                   const torch::Tensor& topk_id, const bool skip_weighted,
+                   const std::string& act, const std::string& isa);
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
       "Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
+      "bool skip_weighted, "
      "str act, str isa) -> ()");
   ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
 #endif
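Note on what the new skip_weighted flag means numerically: when a model applies the router weight on the input (only defined for topk=1), the caller folds the weight into the activations before the expert MLP runs, and the kernel above forces the per-token weight to 1.0f so the expert output is not rescaled a second time. Below is a minimal PyTorch sketch of that reference behavior; the toy experts and all names are hypothetical stand-ins for the fused w13/w2 path, not vLLM APIs.

import torch

# Hypothetical toy experts: H -> FF -> H MLPs standing in for w13/w2 above.
H, FF, E, T = 8, 16, 4, 3
experts = [
    torch.nn.Sequential(
        torch.nn.Linear(H, FF), torch.nn.SiLU(), torch.nn.Linear(FF, H)
    )
    for _ in range(E)
]

def moe_topk1_weight_on_input(x, topk_weights, topk_ids):
    # x: [T, H]; topk_weights: [T, 1] float32; topk_ids: [T, 1] int32.
    # The router weight is folded into the activations *before* the expert
    # MLP, so each expert output is taken unweighted (skip_weighted=True).
    x = x * topk_weights.to(x.dtype)
    out = torch.empty_like(x)
    for t in range(x.size(0)):
        out[t] = experts[int(topk_ids[t, 0])](x[t])
    return out

x = torch.randn(T, H)
topk_weights = torch.rand(T, 1)
topk_ids = torch.randint(0, E, (T, 1), dtype=torch.int32)
print(moe_topk1_weight_on_input(x, topk_weights, topk_ids).shape)  # [3, 8]

Because the expert MLP is nonlinear, weighting the input is not equivalent to weighting the output in general; skip_weighted exists so the weight is applied exactly once, on the input.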
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index ea44beda5..d04edf8e2 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -3078,6 +3078,7 @@ def cpu_fused_moe(
     topk_ids: torch.Tensor,
     act: str,
     isa: str,
+    skip_weighted: bool = False,
 ) -> torch.Tensor:
     output = torch.empty_like(input)
     torch.ops._C.cpu_fused_moe(
@@ -3089,6 +3090,7 @@ def cpu_fused_moe(
         w2_bias,
         topk_weights,
         topk_ids,
+        skip_weighted,
         act,
         isa,
     )
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index e929074d5..127538822 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -238,7 +238,6 @@ class CPUFusedMOE:
         activation: str = "silu",
     ) -> torch.Tensor:
         assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
-        assert not apply_router_weight_on_input
 
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -261,6 +260,7 @@ class CPUFusedMOE:
             topk_ids,
             activation,
             global_num_experts,
+            apply_router_weight_on_input,
         )
 
     def check_grouped_gemm(
@@ -355,7 +355,14 @@ class CPUFusedMOE:
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int = -1,
+        skip_weighted: bool = False,
     ) -> torch.Tensor:
+        if skip_weighted:
+            assert topk_ids.size(1) == 1, (
+                "apply_router_weight_on_input is only implemented for topk=1"
+            )
+            input.mul_(topk_weights.to(input.dtype))
+
         output = cpu_fused_moe(
             input,
             layer.w13_weight,
@@ -366,6 +373,7 @@ class CPUFusedMOE:
             topk_ids,
             activation,
             self.isa,
+            skip_weighted,
         )
         return output
 
@@ -377,7 +385,14 @@ class CPUFusedMOE:
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int = -1,
+        skip_weighted: bool = False,
     ) -> torch.Tensor:
+        if skip_weighted:
+            assert topk_ids.size(1) == 1, (
+                "apply_router_weight_on_input is only implemented for topk=1"
+            )
+            input.mul_(topk_weights.to(input.dtype))
+
         output = torch.empty_like(input)
         layer_id = id(layer)
         torch.ops.vllm.cpu_fused_moe_torch(
@@ -388,6 +403,7 @@ class CPUFusedMOE:
             topk_ids,
             activation,
             global_num_experts,
+            skip_weighted,
         )
         return output
 
@@ -401,6 +417,7 @@ def cpu_fused_moe_torch(
     topk_ids: torch.Tensor,
     activation: str,
     global_num_experts: int = -1,
+    skip_weighted: bool = False,
 ) -> None:
     layer = _CPU_MOE_LAYER_CACHE[layer_id]()
 
@@ -434,13 +451,16 @@ def cpu_fused_moe_torch(
     new_x = torch.empty_like(outs)
     new_x[idxs] = outs
 
-    final_out = (
-        new_x.view(*topk_ids.shape, -1)
-        .type(topk_weights.dtype)
-        .mul_(topk_weights.unsqueeze(dim=-1))
-        .sum(dim=1)
-        .type(new_x.dtype)
-    )
+    if skip_weighted:
+        final_out = new_x
+    else:
+        final_out = (
+            new_x.view(*topk_ids.shape, -1)
+            .type(topk_weights.dtype)
+            .mul_(topk_weights.unsqueeze(dim=-1))
+            .sum(dim=1)
+            .type(new_x.dtype)
+        )
 
     output.copy_(final_out)
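A quick sketch of the broadcast in the input.mul_ fast path added above, with made-up shapes: for topk=1, topk_weights is [T, 1], so the in-place multiply scales each token's whole hidden vector by its single router weight before the kernel is invoked with skip_weighted=True.

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)            # [tokens, hidden]
topk_weights = torch.rand(4, 1, dtype=torch.float32)   # topk=1 router weights
x.mul_(topk_weights.to(x.dtype))                       # broadcasts over dim=1

The to(input.dtype) cast matters: topk_weights is float32 while activations are typically bf16/fp16, and an in-place mul_ with a higher-precision operand would otherwise raise a dtype-cast error.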
diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py
index 8ccd45bb0..2fbcc9c44 100644
--- a/vllm/v1/worker/cpu_worker.py
+++ b/vllm/v1/worker/cpu_worker.py
@@ -160,12 +160,21 @@ class CPUWorker(Worker):
                 x for x in logical_cpu_list if x.numa_node == selected_numa_node
             ]
         else:
-            assert len(logical_cpu_list) >= self.parallel_config.world_size
-            logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
-            sim_cpu_num_per_node = (
-                len(logical_cpu_list) // self.parallel_config.world_size
+            # This is a bit tricky because the internal DP size
+            # is always 1 for non-MoE models
+            world_size_across_dp = (
+                self.parallel_config.world_size
+                * self.parallel_config._api_process_count
             )
-            start_idx = self.local_rank * sim_cpu_num_per_node
+            assert len(logical_cpu_list) >= world_size_across_dp
+            logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
+            sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
+            assert self.parallel_config.data_parallel_rank_local is not None
+            start_idx = (
+                self.local_rank
+                + self.parallel_config.world_size
+                * self.parallel_config.data_parallel_rank_local
+            ) * sim_cpu_num_per_node
             logical_cpu_list = logical_cpu_list[
                 start_idx : (start_idx + sim_cpu_num_per_node)
             ]
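A worked example of the new CPU-slicing arithmetic, with invented numbers: 2 API processes (local DP ranks), each running a world of size 2, over 32 logical CPUs. Every (dp_rank_local, local_rank) pair lands on a disjoint slice, which is what the start_idx expression above guarantees.

# Standalone sketch mirroring the indexing in CPUWorker above; values made up.
logical_cpus = list(range(32))
world_size, api_process_count = 2, 2
world_size_across_dp = world_size * api_process_count     # 4 workers in total
per_worker = len(logical_cpus) // world_size_across_dp    # 8 CPUs per worker

for dp_rank_local in range(api_process_count):
    for local_rank in range(world_size):
        start = (local_rank + world_size * dp_rank_local) * per_worker
        print(dp_rank_local, local_rank, logical_cpus[start:start + per_worker])
# -> (0,0) gets CPUs 0-7, (0,1) gets 8-15, (1,0) gets 16-23, (1,1) gets 24-31.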