diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 840663703..fadf56be1 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -348,7 +348,7 @@ def flashinfer_trtllm_fp4_moe( hidden_states=hidden_states_fp4, hidden_states_scale=hidden_states_scale_linear_fp4.view( torch.float8_e4m3fn - ).flatten(), + ).reshape(*hidden_states_fp4.shape[:-1], -1), gemm1_weights=layer.w13_weight.data, gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), gemm1_bias=None, @@ -432,7 +432,7 @@ def flashinfer_trtllm_fp4_routed_moe( hidden_states=hidden_states_fp4, hidden_states_scale=hidden_states_scale_linear_fp4.view( torch.float8_e4m3fn - ).flatten(), + ).reshape(*hidden_states_fp4.shape[:-1], -1), gemm1_weights=layer.w13_weight.data, gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), gemm1_bias=None,