fix: add Python wrapper for NVFP4 SymmBuffer allocation

get_symm_buffer_for_nvfp4_mega_moe uses _C.get_symm_buffer_size_for_nvfp4_mega_moe
to allocate the correct buffer size (2x SF entries due to group_size=16).
Custom init to avoid SymmBuffer's hardcoded MXFP4 allocation.
This commit is contained in:
2026-05-11 08:05:21 +00:00
parent acbe006498
commit deff80c9c1

View File

@@ -225,6 +225,43 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
)
def get_symm_buffer_for_nvfp4_mega_moe(
group: "dist.ProcessGroup",
num_experts: int,
num_max_tokens_per_rank: int, num_topk: int,
hidden: int, intermediate_hidden: int,
use_fp8_dispatch: bool = True,
activation: str = 'swiglu') -> SymmBuffer:
"""Allocate a SymmBuffer sized for NVFP4 mega_moe (group_size=16)."""
from .. import _C
num_max_tokens_per_rank = align(num_max_tokens_per_rank,
_C.get_token_alignment_for_nvfp4_mega_moe())
buf = SymmBuffer.__new__(SymmBuffer)
buf.group = group
buf.num_experts = num_experts
buf.num_max_tokens_per_rank = num_max_tokens_per_rank
buf.num_topk = num_topk
buf.hidden = hidden
buf.intermediate_hidden = intermediate_hidden
# Use NVFP4-specific buffer size (2x SF due to group_size=16)
num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
group.size(), num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden,
use_fp8_dispatch, activation)
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed as dist
buf.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
buf.handle = symm_mem.rendezvous(buf.buffer, group=group)
buf.buffer.zero_()
buf.group.barrier()
torch.cuda.synchronize()
buf.x, buf.x_sf, buf.topk_idx, buf.topk_weights, \
buf.l1_acts, buf.l1_acts_sf, buf.l2_acts, buf.l2_acts_sf = \
slice_input_buffers(buf.buffer)
return buf
def fp8_nvfp4_mega_moe(y: torch.Tensor,
l1_weights: Tuple[torch.Tensor, torch.Tensor],
l2_weights: Tuple[torch.Tensor, torch.Tensor],