diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
index 40cc6d2ce..a5c5c115f 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -336,7 +336,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
         dbo_yield_and_switch_from_compute_to_comm()
+        assert fused_expert_output.dtype == torch.bfloat16, (
+            f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
+        )
         combined_x, _, event = self.buffer.combine(
+            # HT combine only supports BF16
            x=fused_expert_output,
            handle=handle,
            topk_weights=None,
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index a0ed88309..0fa98b1c7 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -984,7 +984,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             assert num_chunks == 0
             workspace13 = None
             workspace2 = None
-            fused_out = torch.empty_like(a1q)
+            fused_out = torch.empty_like(a1q, dtype=in_dtype)
         else:
             assert num_chunks > 0
             workspace13, workspace2, fused_out = self._allocate_buffers(
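
The `modular_kernel.py` hunk fixes a dtype pitfall: `torch.empty_like` inherits the source tensor's dtype, so allocating the fused output from `a1q` yields a buffer in `a1q`'s dtype rather than in `in_dtype`, which the BF16-only HT combine path (guarded by the new assert in the first hunk) requires. Below is a minimal sketch of the behavior, assuming a PyTorch build with `torch.float8_e4m3fn`; the concrete dtypes are illustrative stand-ins, not values taken from the PR.

```python
import torch

# Illustrative stand-ins: a1q plays the role of the (quantized) activations,
# in_dtype the original activation dtype (assumed values, not from the PR).
a1q = torch.empty(4, 8, dtype=torch.float8_e4m3fn)
in_dtype = torch.bfloat16

# empty_like inherits a1q's dtype unless overridden explicitly...
wrong = torch.empty_like(a1q)
assert wrong.dtype == torch.float8_e4m3fn  # would trip the new BF16 assert

# ...so the fix passes dtype=in_dtype to get a BF16 output buffer.
fused_out = torch.empty_like(a1q, dtype=in_dtype)
assert fused_out.dtype == torch.bfloat16  # satisfies the HT combine requirement
```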