def get_symm_buffer_for_nvfp4_mega_moe(
        group: "dist.ProcessGroup",
        num_experts: int,
        num_max_tokens_per_rank: int, num_topk: int,
        hidden: int, intermediate_hidden: int,
        use_fp8_dispatch: bool = True,
        activation: str = 'swiglu') -> SymmBuffer:
    """Allocate a SymmBuffer sized for NVFP4 mega_moe (group_size=16).

    The buffer object is created with ``SymmBuffer.__new__`` (so
    ``SymmBuffer.__init__`` is *not* run) and its attributes are filled in
    by hand, using the NVFP4-specific size query instead of the default one.

    Args:
        group: Process group used for ``group.size()``, the symmetric-memory
            rendezvous and the barrier.
        num_experts: Stored on the buffer and forwarded to the size query.
        num_max_tokens_per_rank: Aligned with ``align`` to
            ``_C.get_token_alignment_for_nvfp4_mega_moe()`` before use; the
            aligned value is what gets stored on the buffer.
        num_topk: Stored on the buffer and forwarded to the size query.
        hidden: Stored on the buffer and forwarded to the size query.
        intermediate_hidden: Stored on the buffer and forwarded to the size
            query.
        use_fp8_dispatch: Only forwarded to the size query; not stored on
            the returned buffer.
        activation: Only forwarded to the size query; not stored on the
            returned buffer.

    Returns:
        A ``SymmBuffer`` with ``group``, the size parameters, ``buffer``,
        ``handle`` and the eight sliced views (``x``, ``x_sf``, ``topk_idx``,
        ``topk_weights``, ``l1_acts``, ``l1_acts_sf``, ``l2_acts``,
        ``l2_acts_sf``) populated.

    NOTE(review): because ``__init__`` is bypassed, any attribute that
    ``SymmBuffer.__init__`` sets beyond the ones assigned here will be
    missing -- confirm against the class definition.
    NOTE(review): this calls ``symm_mem.rendezvous`` and ``group.barrier()``,
    so presumably every rank in ``group`` has to call it together -- confirm.
    """
    # Imports are kept function-local, as in the original helper.
    # `torch.distributed as dist` was dropped: nothing in the body used it,
    # and a local import cannot resolve the string annotation on `group`.
    from .. import _C
    import torch.distributed._symmetric_memory as symm_mem

    num_max_tokens_per_rank = align(num_max_tokens_per_rank,
                                    _C.get_token_alignment_for_nvfp4_mega_moe())

    # Build the object without running SymmBuffer.__init__ and set the
    # bookkeeping attributes manually.
    buf = SymmBuffer.__new__(SymmBuffer)
    buf.group = group
    buf.num_experts = num_experts
    buf.num_max_tokens_per_rank = num_max_tokens_per_rank
    buf.num_topk = num_topk
    buf.hidden = hidden
    buf.intermediate_hidden = intermediate_hidden

    # Use NVFP4-specific buffer size (2x SF due to group_size=16).
    # The query returns the byte count plus a callable that carves the raw
    # buffer into the individual input views.
    num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
        group.size(), num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation)

    # Order matters here: allocate, rendezvous, zero-fill, then barrier and
    # device sync before anything reads the buffer.
    buf.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
    buf.handle = symm_mem.rendezvous(buf.buffer, group=group)
    buf.buffer.zero_()
    buf.group.barrier()
    torch.cuda.synchronize()

    buf.x, buf.x_sf, buf.topk_idx, buf.topk_weights, \
        buf.l1_acts, buf.l1_acts_sf, buf.l2_acts, buf.l2_acts_sf = \
        slice_input_buffers(buf.buffer)
    return buf