[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     apply_flashinfer_per_tensor_scale_fp8,
     flashinfer_cutlass_moe_fp8,
-    register_moe_scaling_factors,
+    register_scales_for_trtllm_fp8_per_tensor_moe,
     rotate_flashinfer_fp8_moe_weights,
     swap_w13_to_w31,
 )
@@ -85,7 +85,7 @@ class TestData:
 
     @staticmethod
     def make_moe_tensors_8bit(
-        m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
+        m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
     ) -> "TestData":
         is_gated = activation != "relu2_no_mul"
 
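For context on what `make_moe_tensors_8bit` produces: per-tensor FP8 quantization keeps a single float scale for each tensor. A minimal generic sketch of that scheme (illustrative only, not vLLM's implementation):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_per_tensor_fp8(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale for the whole tensor: amax / FP8_MAX, so that w / scale
    # fits inside the representable FP8 range.
    scale = w.abs().amax().float() / FP8_MAX
    w_fp8 = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return w_fp8, scale  # dequantize with w_fp8.float() * scale
```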
@@ -123,12 +123,17 @@ class TestData:
             all2all_backend="naive",
         )
 
-        register_moe_scaling_factors(layer)
-
         # flashinfer expects swapped rows for w13
         layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-        if reorder:
+        if is_trtllm:
             rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
+            register_scales_for_trtllm_fp8_per_tensor_moe(
+                layer,
+                layer.w13_weight_scale,
+                layer.w13_input_scale,
+                layer.w2_weight_scale,
+                layer.w2_input_scale,
+            )
         layer.custom_routing_function = Llama4MoE.custom_routing_function
         layer.intermediate_size_per_partition = n
         layer.ep_rank = 0
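The "swapped rows" comment above refers to the layout of the fused gate/up projection weight. A plausible sketch of the transform, assuming `swap_w13_to_w31` exchanges the two stacked halves along the per-expert output dimension:

```python
import torch

def swap_w13_to_w31_sketch(w13: torch.Tensor) -> torch.Tensor:
    # Assumed layout: w13 stacks [w1; w3] along dim -2 for each expert;
    # flashinfer expects [w3; w1], so swap the two halves.
    w1, w3 = w13.chunk(2, dim=-2)
    return torch.cat([w3, w1], dim=-2).contiguous()
```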
@@ -162,7 +167,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     set_random_seed(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
-        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
+        td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)
 
         score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
         topk_weights, topk_ids = Llama4MoE.custom_routing_function(
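Per the PR title, the TRT-LLM per-tensor path exercised by this test now registers its scales explicitly on the layer (the `register_scales_for_trtllm_fp8_per_tensor_moe` call added above) rather than smuggling them through the quant config. A hypothetical sketch of that pattern; the attribute names and scale products are assumptions, not vLLM's actual code:

```python
import torch

def register_scales_sketch(
    layer,
    w13_weight_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> None:
    # Hypothetical: precompute the dequant products the kernel consumes
    # and attach them directly to the layer object, so the MoE kernel
    # reads them from the layer instead of from the quant config.
    layer.output1_scales = w13_input_scale * w13_weight_scale
    layer.output2_scales = w2_input_scale * w2_weight_scale
```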
@@ -227,7 +232,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         td = TestData.make_moe_tensors_8bit(
-            m, k, n, e, reorder=False, activation=activation
+            m, k, n, e, is_trtllm=False, activation=activation
         )
 
         score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
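Both tests turn raw router scores into `(topk_weights, topk_ids)` via `Llama4MoE.custom_routing_function`. As a generic illustration of that interface (not Llama4's exact scheme, which differs in how it normalizes the selected weights):

```python
import torch

def naive_topk_routing(score: torch.Tensor, top_k: int):
    # Select the top_k experts per token, then renormalize the
    # selected weights so they sum to 1 per token.
    topk_weights, topk_ids = torch.topk(score, top_k, dim=-1)
    topk_weights = torch.softmax(topk_weights.float(), dim=-1)
    return topk_weights, topk_ids
```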