[Performance][DeepGEMM] Estimate expected_m (#28694)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
c9e665852a
commit
6965ef436f
@@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
@@ -19,7 +20,7 @@ from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -313,6 +314,33 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
output = (num_experts, max_num_tokens * num_dispatchers, K)
|
||||
return (workspace13, workspace2, output)
|
||||
|
||||
def estimate_expected_m(
|
||||
self, global_num_experts: int, max_tokens_per_expert: int, topk: int
|
||||
) -> int:
|
||||
dp_meta = (
|
||||
get_forward_context().dp_metadata
|
||||
if is_forward_context_available()
|
||||
else None
|
||||
)
|
||||
if dp_meta is None:
|
||||
logger.warning_once(
|
||||
"DPMetadata unavailable. Defaulting expected_m to "
|
||||
f"{max_tokens_per_expert}.",
|
||||
scope="local",
|
||||
)
|
||||
return max_tokens_per_expert
|
||||
|
||||
total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
|
||||
total_num_tokens_replicated = total_num_tokens * topk
|
||||
|
||||
# Assume even load balancing
|
||||
assert global_num_experts != 0
|
||||
estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
|
||||
# clamp estimate
|
||||
estimate = max(estimate, 16)
|
||||
estimate = min(max_tokens_per_expert, estimate)
|
||||
return estimate
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
@@ -348,10 +376,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||
|
||||
# (from deepgemm docs) : A value hint (which is a value on CPU)
|
||||
# for the M expectation of each batch, correctly setting this value
|
||||
# may lead to better performance.
|
||||
expected_m = max_num_tokens
|
||||
expected_m = self.estimate_expected_m(
|
||||
global_num_experts=global_num_experts,
|
||||
max_tokens_per_expert=max_num_tokens,
|
||||
topk=topk_ids.size(-1),
|
||||
)
|
||||
|
||||
fp8_m_grouped_gemm_nt_masked(
|
||||
(a1q, a1q_scale),
|
||||
(w1, self.w1_scale),
|
||||
|
||||
Reference in New Issue
Block a user