[Performance][DeepGEMM] Estimate expected_m (#28694)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-11-15 00:52:14 -05:00
committed by GitHub
parent c9e665852a
commit 6965ef436f
3 changed files with 73 additions and 17 deletions

View File

@@ -7,6 +7,7 @@ fp8 block-quantized case.
"""
import dataclasses
from contextlib import contextmanager
import pytest
import torch.distributed
@@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
@@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
P = ParamSpec("P")
@contextmanager
def with_dp_metadata(M: int, world_size: int):
num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = world_size
vllm_config.parallel_config.enable_expert_parallel = True
with set_forward_context(
None,
vllm_config,
num_tokens=M,
num_tokens_across_dp=num_tokens_across_dp,
):
yield
def next_power_of_2(x):
import math
@@ -285,18 +304,21 @@ def deepep_deepgemm_moe_impl(
quant_config=quant_config,
)
out = mk.forward(
hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
)
with with_dp_metadata(
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
):
out = mk.forward(
hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
)
return out