diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 7016ff34c..f6b62254e 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -144,12 +144,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): "apply_router_weight_on_input is only implemented for topk=1") a1 = a1 * topk_weights.to(a1.dtype) - if quant_config.per_act_token_quant: + if quant_config.is_block_quantized: + # Quant and Dispatch a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_dtype=quant_config.quant_dtype, - per_act_token_quant=True, + per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape, ) if a1q_scale is not None and a1q_scale.numel() == 1: @@ -162,8 +163,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): rank_topk_weights=topk_weights, num_experts=num_experts) else: - # DeepEP kernels only support dispatching per-token-quant - # quantization. dispatch in bfloat16. + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 (expert_x, _, expert_tokens_meta, expert_topk_ids, expert_topk_weights) = self._do_dispatch( tokens=a1, @@ -171,7 +174,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): rank_topk_ids=topk_ids, rank_topk_weights=topk_weights, num_experts=num_experts) - # quantize now + # Quantize after dispatch. expert_x_scale = None if expert_x.numel() != 0: expert_x, expert_x_scale = moe_kernel_quantize_input(