[Performance][DeepGEMM] Estimate expected_m (#28694)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
c9e665852a
commit
6965ef436f
@@ -7,6 +7,7 @@ fp8 block-quantized case.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
@@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
@@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@contextmanager
def with_dp_metadata(M: int, world_size: int):
    """Run the enclosed code inside a forward context carrying DP metadata.

    A default ``VllmConfig`` is created, its parallel config is set to a
    data-parallel size of ``world_size`` with expert parallelism enabled,
    and ``set_forward_context`` is entered for the duration of the ``with``
    body.

    Args:
        M: Token count passed as ``num_tokens``; the same value is repeated
            once per rank in ``num_tokens_across_dp``.
        world_size: Number of ranks; used as the data-parallel size and as
            the length of ``num_tokens_across_dp``.
    """
    # Every rank is reported as holding the same M tokens (CPU int tensor).
    tokens_per_rank = torch.tensor(
        [M] * world_size,
        device="cpu",
        dtype=torch.int,
    )

    config = VllmConfig()
    parallel = config.parallel_config
    parallel.data_parallel_size = world_size
    parallel.enable_expert_parallel = True

    # The first positional argument is deliberately None, as in the
    # original call.
    forward_ctx = set_forward_context(
        None,
        config,
        num_tokens=M,
        num_tokens_across_dp=tokens_per_rank,
    )
    with forward_ctx:
        yield
|
||||
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
@@ -285,18 +304,21 @@ def deepep_deepgemm_moe_impl(
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
with with_dp_metadata(
|
||||
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
|
||||
):
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user