add NVFP4SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe Python wrapper

The C++ binding was registered, but there was no Python wrapper for it.
The vLLM patch imports get_symm_buffer_for_nvfp4_mega_moe from deep_gemm.mega.
This commit is contained in:
2026-05-11 16:25:08 +00:00
parent 86a1263f44
commit 47621bb990

View File

@@ -72,6 +72,70 @@ def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
)
class NVFP4SymmBuffer:
    """Symmetric buffer for NVFP4 mega_moe kernel.

    Same structure as SymmBuffer but with NVFP4-specific SF layout:
    group_size=16 means 2x the SF data compared to MXFP4 (group_size=32).
    x_sf is packed int32 (4 UE4M3 per int32) with shape (M, K//64).

    Construction allocates one int8 symmetric-memory buffer on the CUDA
    device, rendezvouses it over ``group``, zeroes it, and slices it into the
    per-tensor views exposed as attributes (``x``, ``x_sf``, ``topk_idx``,
    ``topk_weights``, ``l1_acts``, ``l1_acts_sf``, ``l2_acts``,
    ``l2_acts_sf``). Because it performs a rendezvous and a barrier, every
    rank in ``group`` is expected to construct the buffer together.
    """

    def __init__(self, group: dist.ProcessGroup,
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate the symmetric buffer and create its input views.

        Args:
            group: Process group the buffer is shared across.
            num_experts: Forwarded to the C++ size query.
            num_max_tokens_per_rank: Forwarded to the C++ size query.
                NOTE(review): the factory function aligns this value before
                calling the constructor; direct callers presumably need to
                do the same - confirm against the kernel's requirements.
            num_topk: Forwarded to the C++ size query.
            hidden: Forwarded to the C++ size query.
            intermediate_hidden: Forwarded to the C++ size query.
            use_fp8_dispatch: Forwarded to the C++ size query.
            activation: Forwarded to the C++ size query.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer (NVFP4 variant). The binding returns
        # the total byte size and a callable that slices the buffer into views.
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Start from an all-zero buffer; the barrier and device sync presumably
        # ensure no rank uses the buffer before every rank has cleared its own.
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the symmetric buffer.

        All sliced views are cleared, not just ``x`` and ``x_sf``: any view
        left behind would keep referencing the underlying allocation and
        remain usable as a stale tensor after the buffer has been released.
        Safe to call more than once.
        """
        # Release the views first, then the backing buffer and its handle.
        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None
        self.handle = None
        self.buffer = None
        self.group = None
def get_symm_buffer_for_nvfp4_mega_moe(group: dist.ProcessGroup,
                                       num_experts: int,
                                       num_max_tokens_per_rank: int, num_topk: int,
                                       hidden: int, intermediate_hidden: int,
                                       use_fp8_dispatch: bool = True,
                                       activation: str = 'swiglu') -> NVFP4SymmBuffer:
    """Build an ``NVFP4SymmBuffer`` with an aligned per-rank token capacity.

    ``num_max_tokens_per_rank`` is passed through ``align`` with the token
    alignment reported by the C++ binding before the buffer is constructed;
    every other argument is forwarded to ``NVFP4SymmBuffer`` unchanged.

    Returns:
        The newly constructed ``NVFP4SymmBuffer``.
    """
    # Token count must be aligned to block sizes
    token_alignment = _C.get_token_alignment_for_nvfp4_mega_moe()
    aligned_num_tokens = align(num_max_tokens_per_rank, token_alignment)

    return NVFP4SymmBuffer(
        group=group,
        num_experts=num_experts,
        num_max_tokens_per_rank=aligned_num_tokens,
        num_topk=num_topk,
        hidden=hidden,
        intermediate_hidden=intermediate_hidden,
        use_fp8_dispatch=use_fp8_dispatch,
        activation=activation,
    )
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
def interleave(t, gran: int = 8) -> torch.Tensor: