[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2026-01-06 10:47:04 -05:00
committed by GitHub
parent d3e477c013
commit af8fd73051
7 changed files with 174 additions and 85 deletions

View File

@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31,
)
@@ -85,7 +85,7 @@ class TestData:
@staticmethod
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
) -> "TestData":
is_gated = activation != "relu2_no_mul"
@@ -123,12 +123,17 @@ class TestData:
all2all_backend="naive",
)
register_moe_scaling_factors(layer)
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder:
if is_trtllm:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
@@ -162,7 +167,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
@@ -227,7 +232,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, reorder=False, activation=activation
m, k, n, e, is_trtllm=False, activation=activation
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)