[Kernel] Integrate batched/masked deepgemm kernel (#19111)

Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-06-04 17:59:18 -04:00
committed by GitHub
parent ef3f98b59f
commit c3fd4d669a
6 changed files with 472 additions and 51 deletions

View File

@@ -162,12 +162,14 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
quant_dtype=q_dtype,
block_shape=block_shape,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)
@@ -185,4 +187,5 @@ def make_deepep_a2a(pg: ProcessGroup,
block_shape)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype)
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)