fix: add Python wrapper for NVFP4 SymmBuffer allocation
`get_symm_buffer_for_nvfp4_mega_moe` calls `_C.get_symm_buffer_size_for_nvfp4_mega_moe` to allocate the correct buffer size (2x SF entries due to group_size=16). It initializes the `SymmBuffer` manually (via `SymmBuffer.__new__`) to avoid SymmBuffer's hardcoded MXFP4 allocation.
This commit is contained in:
@@ -225,6 +225,43 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
def get_symm_buffer_for_nvfp4_mega_moe(
        group: "dist.ProcessGroup",
        num_experts: int,
        num_max_tokens_per_rank: int, num_topk: int,
        hidden: int, intermediate_hidden: int,
        use_fp8_dispatch: bool = True,
        activation: str = 'swiglu') -> SymmBuffer:
    """Allocate a SymmBuffer sized for NVFP4 mega_moe (group_size=16).

    The buffer object is created with ``SymmBuffer.__new__`` so that
    ``SymmBuffer.__init__`` does not run; the byte size and the slicing
    function come from ``_C.get_symm_buffer_size_for_nvfp4_mega_moe``
    instead.

    Args:
        group: Process group used for the symmetric-memory rendezvous and
            for the barrier after the buffer is zeroed.
        num_experts: Forwarded to the size query and stored on the buffer.
        num_max_tokens_per_rank: Rounded with ``align`` to
            ``_C.get_token_alignment_for_nvfp4_mega_moe()`` before use; the
            aligned value is what gets stored on the returned buffer.
        num_topk: Forwarded to the size query and stored on the buffer.
        hidden: Forwarded to the size query and stored on the buffer.
        intermediate_hidden: Forwarded to the size query and stored on the
            buffer.
        use_fp8_dispatch: Forwarded to the size query only.
        activation: Forwarded to the size query only.

    Returns:
        A ``SymmBuffer`` with ``group``, ``num_experts``,
        ``num_max_tokens_per_rank``, ``num_topk``, ``hidden``,
        ``intermediate_hidden``, ``buffer``, ``handle`` and the eight sliced
        views (``x``, ``x_sf``, ``topk_idx``, ``topk_weights``, ``l1_acts``,
        ``l1_acts_sf``, ``l2_acts``, ``l2_acts_sf``) populated.
    """
    from .. import _C
    import torch.distributed._symmetric_memory as symm_mem

    num_max_tokens_per_rank = align(
        num_max_tokens_per_rank,
        _C.get_token_alignment_for_nvfp4_mega_moe())

    # Bypass SymmBuffer.__init__ on purpose: only the attributes assigned
    # below exist on the returned object.
    # NOTE(review): confirm SymmBuffer.__init__ sets no other attribute that
    # downstream code reads (e.g. anything derived from use_fp8_dispatch or
    # activation).
    buf = SymmBuffer.__new__(SymmBuffer)
    buf.group = group
    buf.num_experts = num_experts
    buf.num_max_tokens_per_rank = num_max_tokens_per_rank
    buf.num_topk = num_topk
    buf.hidden = hidden
    buf.intermediate_hidden = intermediate_hidden

    # Use NVFP4-specific buffer size (2x SF due to group_size=16)
    num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
        group.size(), num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation)

    # Allocate, rendezvous across the group, then zero the buffer. The
    # barrier and device synchronize run before the buffer is sliced and
    # handed back to the caller.
    buf.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
    buf.handle = symm_mem.rendezvous(buf.buffer, group=group)
    buf.buffer.zero_()
    buf.group.barrier()
    torch.cuda.synchronize()

    # Carve the raw byte buffer into the eight input views.
    (buf.x, buf.x_sf, buf.topk_idx, buf.topk_weights,
     buf.l1_acts, buf.l1_acts_sf, buf.l2_acts, buf.l2_acts_sf) = \
        slice_input_buffers(buf.buffer)
    return buf
|
||||
|
||||
|
||||
def fp8_nvfp4_mega_moe(y: torch.Tensor,
|
||||
l1_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
Reference in New Issue
Block a user