# Provenance (recovered from the format-patch envelope this code arrived in):
#   commit  47621bb99072432ba194af079054b5a3d4db3251
#   author  biondizzle, Mon, 11 May 2026 16:25:08 +0000
#   subject [PATCH] add NVFP4SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe
#           Python wrapper
#
#   The C++ binding was registered but there was no Python wrapper. The vLLM
#   patch imports get_symm_buffer_for_nvfp4_mega_moe from deep_gemm.mega.
#
#   target  deep_gemm/mega/__init__.py (index 8e6437f..dd455f2), original hunk
#           @@ -72,6 +72,70 @@ -- the definitions below are inserted after
#           get_symm_buffer_for_mega_moe and before _interleave_l1_weights.
#
# Names used but not defined here (`torch`, `dist`, `symm_mem`, `_C`, `align`)
# are expected to come from the target module's existing imports.


class NVFP4SymmBuffer:
    """Symmetric buffer for NVFP4 mega_moe kernel.

    Same structure as SymmBuffer but with NVFP4-specific SF layout:
    group_size=16 means 2x the SF data compared to MXFP4 (group_size=32).
    x_sf is packed int32 (4 UE4M3 per int32) with shape (M, K//64).

    Construction calls ``symm_mem.rendezvous`` and ``group.barrier()``, so it
    is presumably a collective operation that every rank of ``group`` must
    enter together -- confirm against the callers.

    Attributes:
        group: Process group the buffer was rendezvoused on.
        num_experts, num_max_tokens_per_rank, num_topk, hidden,
        intermediate_hidden: The sizing arguments, stored as passed.
        buffer: Flat ``torch.int8`` CUDA tensor backing every view.
        handle: Object returned by ``symm_mem.rendezvous``.
        x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts,
        l2_acts_sf: Views into ``buffer`` produced by the slicing callable
            that ``_C.get_symm_buffer_size_for_nvfp4_mega_moe`` returns.
    """

    def __init__(self, group: dist.ProcessGroup,
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu') -> None:
        """Allocate, rendezvous and zero the buffer, then slice it into views.

        Args:
            group: Process group to rendezvous the symmetric buffer on.
            num_experts: Forwarded to the C++ sizing routine.
            num_max_tokens_per_rank: Forwarded to the C++ sizing routine; it
                is NOT aligned here (the factory function does that).
            num_topk: Forwarded to the C++ sizing routine.
            hidden: Forwarded to the C++ sizing routine.
            intermediate_hidden: Forwarded to the C++ sizing routine.
            use_fp8_dispatch: Forwarded to the C++ sizing routine.
            activation: Forwarded to the C++ sizing routine.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Ask the extension how many bytes the NVFP4 variant needs; it also
        # hands back the callable that carves the flat buffer into views.
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Clear the buffer, then barrier + device sync before anything uses
        # it. NOTE(review): presumably this keeps a peer from touching the
        # buffer before every rank has finished zeroing -- confirm.
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self) -> None:
        """Drop every reference this object holds on the buffer.

        All eight views are cleared, not just ``x`` and ``x_sf``: a view left
        behind would keep the underlying allocation referenced after
        ``buffer`` itself is set to ``None``. Safe to call more than once, and
        on an instance whose ``__init__`` did not run to completion, because
        it only assigns attributes.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None


def get_symm_buffer_for_nvfp4_mega_moe(group: dist.ProcessGroup,
                                       num_experts: int,
                                       num_max_tokens_per_rank: int, num_topk: int,
                                       hidden: int, intermediate_hidden: int,
                                       use_fp8_dispatch: bool = True,
                                       activation: str = 'swiglu') -> NVFP4SymmBuffer:
    """Build an ``NVFP4SymmBuffer`` with an aligned per-rank token capacity.

    ``num_max_tokens_per_rank`` is passed through ``align`` with the value
    reported by ``_C.get_token_alignment_for_nvfp4_mega_moe()``; every other
    argument is forwarded to ``NVFP4SymmBuffer`` unchanged.

    Returns:
        The newly constructed ``NVFP4SymmBuffer``.
    """
    # Token count must be aligned to block sizes
    token_alignment = _C.get_token_alignment_for_nvfp4_mega_moe()
    aligned_num_tokens = align(num_max_tokens_per_rank, token_alignment)

    return NVFP4SymmBuffer(
        group, num_experts,
        aligned_num_tokens, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )