diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
index cf46eefa3..6f490f00b 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
@@ -112,6 +112,24 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
         ]
         return (weight_key, activation_key) in SUPPORTED_W_A
 
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        """Override to handle 4D BlockMajorK weights (E, K/bk, Mn, bk)."""
+        if w1.dim() == 4:
+            # BlockMajorK: (E, K/bk, Mn, bk)
+            E = w1.shape[0]
+            N = w1.shape[2]
+            K = a1.size(-1)
+            M = a1.size(0) if a1.dim() == 2 else a1.size(1)
+            topk = topk_ids.size(1)
+            return E, M, N, K, topk
+        return super().moe_problem_size(a1, w1, w2, topk_ids)
+
     def workspace_shapes(
         self,
         M: int,
@@ -152,7 +170,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
         apply_router_weight_on_input: bool,
     ):
         import flashinfer
-        from flashinfer.fused_moe import Fp8QuantizationType
+        from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
 
         # Pack topk ids and weights into format expected by the kernel.
         packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
@@ -170,10 +188,12 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
         if is_mxfp8:
             fp8_quant_type = Fp8QuantizationType.MxFp8
             use_shuffled_weight = True
+            weight_layout = WeightLayout.MajorK
             hidden_states_scale = a1q_scale
         else:
             fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
-            use_shuffled_weight = False
+            use_shuffled_weight = True
+            weight_layout = WeightLayout.BlockMajorK
             hidden_states_scale = a1q_scale.t().contiguous()
 
         # `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
@@ -199,7 +219,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
             routed_scaling_factor=None,
             routing_method_type=1,
             use_shuffled_weight=use_shuffled_weight,
-            weight_layout=0,
+            weight_layout=weight_layout,
             fp8_quantization_type=fp8_quant_type,
             # output=output,
         )
@@ -322,7 +342,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
         topk_group: int | None = None,
     ) -> torch.Tensor:
         import flashinfer
-        from flashinfer.fused_moe import Fp8QuantizationType
+        from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
 
         assert not apply_router_weight_on_input
         assert activation == MoEActivation.SILU
@@ -342,10 +362,12 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
         if is_mxfp8:
             fp8_quant_type = Fp8QuantizationType.MxFp8
             use_shuffled_weight = True
+            weight_layout = WeightLayout.MajorK
             hidden_states_scale = a1q_scale
         else:
             fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
-            use_shuffled_weight = False
+            use_shuffled_weight = True
+            weight_layout = WeightLayout.BlockMajorK
             hidden_states_scale = a1q_scale.t().contiguous()
 
         return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
@@ -367,6 +389,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
             routed_scaling_factor=routed_scaling_factor,
             routing_method_type=self.routing_method_type,
             use_shuffled_weight=use_shuffled_weight,
+            weight_layout=weight_layout,
             fp8_quantization_type=fp8_quant_type,
         )
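The `moe_problem_size` override above recovers the logical GEMM dimensions when the weight arrives in the 4D BlockMajorK layout rather than the usual 2D/3D shape. A minimal sketch of that shape arithmetic, using made-up sizes (`E`, `N`, `K`, `M`, `topk` here are illustrative values, not from the diff; `block_k = 128` matches the block size used above):

```python
import torch

# Hypothetical problem sizes for illustration.
E, N, K, block_k, M, topk = 8, 512, 1024, 128, 4, 2

# BlockMajorK weight: (E, K / block_k, Mn, block_k).
w1 = torch.empty(E, K // block_k, N, block_k, dtype=torch.float8_e4m3fn)
a1 = torch.empty(M, K, dtype=torch.bfloat16)
topk_ids = torch.zeros(M, topk, dtype=torch.int32)

assert w1.shape[0] == E                # experts
assert w1.shape[2] == N                # per-expert output dim (Mn)
assert a1.size(-1) == K                # K is read from the activations,
assert w1.shape[1] * w1.shape[3] == K  # but is also recoverable from w1
assert topk_ids.size(1) == topk
```

Reading `K` from the activations rather than `w1` is what lets the same override serve both the 4D layout and the base-class path.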
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 66827488f..13c82893d 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -305,6 +305,39 @@ def align_fp8_moe_weights_for_fi(
     return padded_w13, padded_w2, padded_intermediate
 
 
+def _shuffle_deepseek_fp8_moe_weights(
+    w13: torch.Tensor,
+    w2: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Preprocess DeepSeek FP8 block-scale weights for the FlashInfer TRT-LLM
+    kernel using the shuffle + BlockMajorK layout variant.
+
+    Returns 4D weight tensors in BlockMajorK layout
+    (E, K/block_k, Mn, block_k)
+    """
+    from flashinfer import shuffle_matrix_a
+    from flashinfer.fused_moe import convert_to_block_layout
+
+    epilogue_tile_m = 64
+    block_k = 128
+    num_experts = w13.shape[0]
+
+    w13_shuffled: list[torch.Tensor] = []
+    w2_shuffled: list[torch.Tensor] = []
+    for i in range(num_experts):
+        t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
+        t13 = convert_to_block_layout(t13, block_k)
+        w13_shuffled.append(t13)
+
+        t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
+        t2 = convert_to_block_layout(t2, block_k)
+        w2_shuffled.append(t2)
+
+    w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
+    w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
+    return w13_out, w2_out
+
+
 def _shuffle_mxfp8_moe_weights(
     w13: torch.Tensor,
     w2: torch.Tensor,
@@ -405,6 +438,7 @@ def prepare_fp8_moe_layer_for_fi(
         hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
     )
     is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
+    is_deepseek_fp8 = block_quant and not is_mxfp8
    is_gated = layer.activation.is_gated
 
     # MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
@@ -447,6 +481,10 @@ def prepare_fp8_moe_layer_for_fi(
     if block_quant:
         w13_scale = swap_w13_to_w31(w13_scale)
 
+    # DeepSeekFp8 TRT-LLM: shuffle weights into BlockMajorK layout.
+    if is_deepseek_fp8 and is_trtllm:
+        w13, w2 = _shuffle_deepseek_fp8_moe_weights(w13, w2)
+
     # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
     # and registration of alpha scales.
     if is_trtllm and not block_quant:
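For intuition about the layout `convert_to_block_layout` produces, here is a plain-PyTorch sketch of the target transform. This is not FlashInfer's implementation, and the `shuffle_matrix_a` epilogue-tile reordering is deliberately omitted; it only shows how a row-major per-expert `(Mn, K)` matrix maps to `(K / block_k, Mn, block_k)`:

```python
import torch

def block_major_k(w: torch.Tensor, block_k: int = 128) -> torch.Tensor:
    # Split the K dim into block_k-wide chunks and make the chunk index
    # the leading dim: (Mn, K) -> (K / block_k, Mn, block_k).
    mn, k = w.shape
    assert k % block_k == 0
    return w.reshape(mn, k // block_k, block_k).transpose(0, 1).contiguous()

# Per-expert byte view, mirroring _shuffle_deepseek_fp8_moe_weights (minus
# the shuffle step); stacking num_experts of these yields the 4D
# (E, K / block_k, Mn, block_k) tensor the kernel consumes.
w_expert = torch.randint(0, 255, (512, 1024), dtype=torch.uint8)
print(block_major_k(w_expert).shape)  # torch.Size([8, 512, 128])
```

This 4D shape is exactly why the `moe_problem_size` override in the first file is needed: the N and K dims are no longer where the base class expects them.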
diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
index 1cafccd49..07476906e 100644
--- a/vllm/model_executor/warmup/deep_gemm_warmup.py
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -13,7 +13,7 @@ import vllm.envs as envs
 from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
     TritonOrDeepGemmExperts,
 )
@@ -168,14 +168,12 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     ):
         return False
 
-    if not isinstance(module.quant_method, FusedMoEModularMethod):
-        # modular kernels could invoke deep_gemm_moe_fp8
-        return True
+    moe_kernel = getattr(module.quant_method, "moe_kernel", None)
+    if moe_kernel is None:
+        return False
 
-    # Further check if the ModularKernel implementation uses the DeepGemmExperts
-    return isinstance(
-        module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
-    )
+    fused_experts = moe_kernel.impl.fused_experts
+    return isinstance(fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
 
 
 FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
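The warmup change replaces the `FusedMoEModularMethod` isinstance check with duck typing on a `moe_kernel` attribute and inspects `moe_kernel.impl.fused_experts` for the DeepGEMM expert types. A self-contained sketch of that detection pattern, using stand-in objects (`DeepGemmLikeExperts` and the `SimpleNamespace` stubs are hypothetical, not vLLM types):

```python
from types import SimpleNamespace

class DeepGemmLikeExperts:
    """Stand-in for DeepGemmExperts / TritonOrDeepGemmExperts."""

def may_use_deep_gemm(quant_method) -> bool:
    # No moe_kernel attribute -> this quant method never dispatches to
    # DeepGEMM, so warmup is skipped (the old code returned True here).
    moe_kernel = getattr(quant_method, "moe_kernel", None)
    if moe_kernel is None:
        return False
    return isinstance(moe_kernel.impl.fused_experts, DeepGemmLikeExperts)

with_kernel = SimpleNamespace(
    moe_kernel=SimpleNamespace(
        impl=SimpleNamespace(fused_experts=DeepGemmLikeExperts())
    )
)
without_kernel = SimpleNamespace()

assert may_use_deep_gemm(with_kernel)
assert not may_use_deep_gemm(without_kernel)
```

Note the behavioral flip: methods without a modular kernel used to conservatively return True and get warmed up; after this change they return False, so only kernels that demonstrably wrap DeepGEMM experts are warmed.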