[EP+DP] Optimize the little operations in the DeepGEMM + DeepEP low latency case (#19885)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author:       Tyler Michael Smith
Date:         2025-06-23 14:07:47 -04:00
Committed by: GitHub
Parent:       c3649e4fee
Commit:       68aaeb3749
3 changed files with 263 additions and 18 deletions


@@ -45,7 +45,8 @@ if current_platform.is_cuda_alike():
     from .pplx_prepare_finalize import PplxPrepareAndFinalize
     if has_deepep:
         from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
-        from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
+        from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
+                                                 DeepEPLLPrepareAndFinalize)
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = None  # type: ignore
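
For reference, DEEPEP_QUANT_BLOCK_SIZE is the per-block quantization width that
DeepEP's low-latency FP8 dispatch assumes. A minimal sketch of what the new
import provides (the value 128 matches DeepEP's low-latency layout, but it is
copied from deepep_ll_prepare_finalize.py as an assumption rather than shown
in this diff):

# Assumed definition in deepep_ll_prepare_finalize.py: DeepEP's low-latency
# kernels quantize the hidden dimension in fixed 128-element blocks, with
# one FP8 scale per block.
DEEPEP_QUANT_BLOCK_SIZE = 128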
@@ -377,6 +378,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 all2all_manager.world_size)
             handle = all2all_manager.get_handle(all_to_all_args)
 
+            # Note: we may want to use FP8 dispatch even otherwise, just to
+            # reduce data movement.
+            assert act_quant_block_size is not None
+            use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
+                                and act_quant_block_size[1]
+                                == DEEPEP_QUANT_BLOCK_SIZE)
+
             # Note (varun): Whether to use FP8 dispatch or not needs some
             # profiling. Turning it off for now.
             prepare_finalize = DeepEPLLPrepareAndFinalize(
@@ -386,7 +394,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 max_tokens_per_rank=moe.max_num_tokens,
                 quant_dtype=quant_dtype,
                 block_shape=act_quant_block_size,
-                use_fp8_dispatch=False,
+                use_fp8_dispatch=use_fp8_dispatch,
             )
         self.topk_indices_dtype = None
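
Taken together, the change enables DeepEP's FP8 dispatch only when the model's
activation quantization is compatible with the transport's block layout. A
standalone sketch of the gating condition, assuming CUDA's FP8 dtype
torch.float8_e4m3fn (what current_platform.fp8_dtype() resolves to on NVIDIA
hardware) and a hypothetical helper name should_use_fp8_dispatch:

import torch

DEEPEP_QUANT_BLOCK_SIZE = 128  # assumed DeepEP low-latency quant block width

def should_use_fp8_dispatch(quant_dtype: torch.dtype,
                            act_quant_block_size: list[int]) -> bool:
    # FP8 dispatch quantizes tokens before the all-to-all, so roughly half
    # the bytes cross NVLink/RDMA compared to dispatching BF16 activations.
    # It is only valid when the model already quantizes activations to FP8
    # in blocks whose trailing dimension matches DeepEP's block width.
    return (quant_dtype == torch.float8_e4m3fn
            and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE)

The result feeds DeepEPLLPrepareAndFinalize(use_fp8_dispatch=...), replacing
the hard-coded False that the earlier "needs some profiling" note had left in
place.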