# Python-side helpers for the mega MoE kernels: symmetric-memory buffer
# allocation (``SymmBuffer``), weight-layout transforms for the FP4 / NVFP4
# weight paths, and thin wrappers that forward to the ``_C`` kernels.
import torch
from typing import Tuple, Optional

from ..utils.math import align

# Symmetric memory is a private, version-dependent PyTorch API. If it cannot be
# imported we report it and keep this module importable; ``SymmBuffer`` raises
# a clear error if it is actually used. ``dist`` is imported first so it stays
# bound even when only the symmetric-memory import fails, and it is referenced
# only in string annotations so an unbound ``dist`` cannot break module load.
# noinspection PyBroadException
try:
    import torch.distributed as dist
    # noinspection PyProtectedMember
    import torch.distributed._symmetric_memory as symm_mem
except Exception as exception:
    print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
    symm_mem = None

from .. import _C


class SymmBuffer:
    """Symmetric-memory workspace for the mega MoE kernels.

    Allocates one ``int8`` buffer with ``symm_mem.empty`` (sized by
    ``_C.get_symm_buffer_size_for_mega_moe``), rendezvouses it across
    ``group``, zeroes it, and exposes the sub-buffers returned by the
    ``_C``-provided slicing function as the attributes ``x``, ``x_sf``,
    ``topk_idx``, ``topk_weights``, ``l1_acts``, ``l1_acts_sf``, ``l2_acts``
    and ``l2_acts_sf``.

    ``get_symm_buffer_for_mega_moe`` aligns ``num_max_tokens_per_rank``
    before constructing one; constructing directly skips that alignment.
    """

    def __init__(self, group: 'dist.ProcessGroup',
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and zero the symmetric buffer.

        Raises:
            RuntimeError: if ``torch.distributed._symmetric_memory`` could not
                be imported when this module was loaded.
        """
        if symm_mem is None:
            raise RuntimeError('torch.distributed._symmetric_memory could not be imported; '
                               'mega MoE kernels are unavailable with this PyTorch build')

        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)
        self.buffer.zero_()
        # NOTE(review): ``ProcessGroup.barrier()`` returns a work handle that is
        # not waited on here; the device-wide synchronize below is presumably
        # what makes the barrier effective (for NCCL groups) — confirm.
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the symmetric buffer.

        Clears the handle, the buffer, the group and *all* sub-buffer views.
        The views were created from ``self.buffer``, so leaving any of them
        set would keep the underlying allocation referenced.
        """
        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None
        self.handle = None
        self.buffer = None
        self.group = None


def get_symm_buffer_for_mega_moe(group: 'dist.ProcessGroup',
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Create a ``SymmBuffer`` after aligning the per-rank token capacity.

    ``num_max_tokens_per_rank`` is passed through ``align`` with
    ``_C.get_token_alignment_for_mega_moe()`` (presumably rounding it up to a
    multiple of that value) before allocation; every other argument is
    forwarded to ``SymmBuffer`` unchanged.
    """
    # Token count must be aligned to block sizes
    num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe())
    return SymmBuffer(
        group, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )


def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Interleave the gate/up halves of both L1 tensors along dim 1.

    Each tensor's dim 1 is split into a first half (gate) and second half
    (up), which are re-laid out as alternating chunks of 8 rows:
    ``[gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...]`` instead of
    ``[gate | up]``. The same reordering is applied to ``l1_weights[0]`` and
    ``l1_weights[1]`` so rows stay in correspondence. Dim 1 must be even and
    each half divisible by 8, otherwise the reshape fails.
    """
    def interleave(t, gran: int = 8) -> torch.Tensor:
        g, n, *rest = t.shape
        half = n // 2
        gate = t[:, :half].reshape(g, half // gran, gran, *rest)
        up = t[:, half:].reshape(g, half // gran, gran, *rest)
        # Stack -> (g, half // gran, 2, gran, *rest), then flatten back to n rows
        return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))

    return interleave(l1_weights[0]), interleave(l1_weights[1])


def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Apply the UTCCP 4x32 row transpose to int32 packed scale factors.

    Within every block of 128 rows along dim 1, row ``a * 32 + b`` (``a`` in
    0..3, ``b`` in 0..31) moves to position ``b * 4 + a``. Requires an
    ``int32`` tensor of shape ``(num_groups, mn, packed_sf_k)`` with
    ``mn % 128 == 0``. Returns a new contiguous tensor of the same shape.
    """
    num_groups, mn, packed_sf_k = sf.shape
    assert sf.dtype == torch.int and mn % 128 == 0
    result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
              .transpose(2, 3)
              .reshape(num_groups, mn, packed_sf_k))
    return torch.empty_like(sf).copy_(result)


def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Re-lay out ``(weight, sf)`` pairs for the mega MoE kernel.

    L1: gate/up interleave on both tensors, then the UTCCP transpose on the
    scale factors. L2: only the UTCCP transpose on the scale factors.
    Scale factors must already be int32 (asserted by the transpose).
    """
    # L1: interleave gate/up, then transpose SF for UTCCP
    l1_interleaved = _interleave_l1_weights(l1_weights)
    l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1]))

    # L2: only transpose SF for UTCCP
    l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))
    return l1_weights, l2_weights


def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.

    NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X). The UTCCP
    layout packs 4 consecutive scale bytes into each int32 (byte ``4*i + j``
    goes to bits ``8*j .. 8*j+7`` of word ``i``), then applies the same 4x32
    transpose as ``_transpose_sf_for_utccp``.

    Input:  (num_experts, mn, K//16) 1-byte scales (float8_e4m3fn)
    Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
    """
    _, mn, sf_k = sf.shape
    # A wider dtype would make ``view(torch.uint8)`` grow the last dim, and the
    # packing below would no longer line up with ``sf_k``.
    assert sf.element_size() == 1, f"NVFP4 SF must use a 1-byte dtype, got {sf.dtype}"
    assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
    assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"

    # View as uint8 and pack 4 consecutive bytes into int32
    sf_uint8 = sf.view(torch.uint8)  # (num_groups, mn, sf_k)

    # Pack: every 4 uint8 -> 1 int32 (the top byte may set the sign bit; only
    # the bit pattern matters to the kernel)
    packed = (sf_uint8[..., 0::4].to(torch.int32)
              | (sf_uint8[..., 1::4].to(torch.int32) << 8)
              | (sf_uint8[..., 2::4].to(torch.int32) << 16)
              | (sf_uint8[..., 3::4].to(torch.int32) << 24))  # (num_groups, mn, sf_k//4)

    # Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
    # by the 128-element alignment, not the scale vector size)
    return _transpose_sf_for_utccp(packed)


def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 expert weights for the mega_moe kernel.

    NVFP4 weights come as (weight, scale) where:
      - weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
      - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)

    The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout.
    """
    # L1: interleave gate/up, then pack + transpose SF for UTCCP
    l1_interleaved = _interleave_l1_weights(l1_weights)
    l1_weights = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))

    # L2: only pack + transpose SF for UTCCP
    l2_weights = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1]))
    return l1_weights, l2_weights


def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Launch the FP8 x FP4 mega MoE kernel (``_C.fp8_fp4_mega_moe``).

    Forwards the weights plus the buffer, peer pointers, rank and sizing
    taken from ``sym_buffer``. Returns ``None``; results are presumably
    written into ``y`` — confirm against the kernel. ``l1_weights`` /
    ``l2_weights`` are presumably the output of
    ``transform_weights_for_mega_moe``. The ``recipe`` default ``(1, 1, 32)``
    presumably encodes the scale granularity (cf. ``(1, 1, 16)`` for NVFP4).
    """
    _C.fp8_fp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer, sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )


def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: SymmBuffer,
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 16),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X with
    UE4M3 block scales (group_size=16).

    Weight format: (uint8 E2M1 packed, int32 packed UTCCP UE4M3 scales), as
    produced by ``transform_nvfp4_weights_for_mega_moe``.
    Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16.
    Returns ``None``; results are presumably written into ``y`` — confirm.
    """
    _C.fp8_nvfp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer, sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )