import torch
from typing import Tuple, Optional

from ..utils.math import align

# noinspection PyBroadException
try:
    # noinspection PyProtectedMember
    import torch.distributed._symmetric_memory as symm_mem
    import torch.distributed as dist
except Exception as exception:
    print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')

from .. import _C


class SymmBuffer:
    """Symmetric-memory buffer plus the input views used by the mega MoE kernels.

    The buffer is allocated with ``symm_mem.empty``, registered across ``group``
    with ``symm_mem.rendezvous``, zero-filled, and then sliced into named views
    by the callable returned from the ``_C`` sizing function.
    """

    # NOTES: `dist.ProcessGroup` is quoted so the class can still be defined
    # when the guarded `torch.distributed` import above has failed
    def __init__(self, group: 'dist.ProcessGroup',
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self._allocate(num_bytes, slice_input_buffers)

    def _allocate(self, num_bytes, slice_input_buffers) -> None:
        """Allocate, rendezvous and zero the buffer, then create the input views.

        Requires ``self.group`` to be set. ``num_bytes`` is the size passed to
        ``symm_mem.empty`` and ``slice_input_buffers`` is called with the new
        buffer and must return the eight views unpacked below.
        """
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=self.group)
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the buffer, handle and group."""
        self.handle = None
        self.buffer = None
        self.group = None
        # Every view produced by `slice_input_buffers` must be released as well,
        # otherwise the views keep the underlying storage referenced
        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None


def get_symm_buffer_for_mega_moe(group: 'dist.ProcessGroup',
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Create a `SymmBuffer` after aligning the per-rank token count."""
    # Token count must be aligned to block sizes
    num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe())

    return SymmBuffer(
        group, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )


def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Interleave the gate and up halves of both tensors along dim 1.

    The same permutation is applied to both elements of the tuple, and each
    result is a freshly allocated tensor.
    """
    # [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
    def interleave(t, gran: int = 8) -> torch.Tensor:
        g, n, *rest = t.shape
        half = n // 2
        gate = t[:, :half].reshape(g, half // gran, gran, *rest)
        up = t[:, half:].reshape(g, half // gran, gran, *rest)
        return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))

    return interleave(l1_weights[0]), interleave(l1_weights[1])


def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Swap the 4 and 32 sub-axes inside every 128-row chunk of dim 1.

    `sf` must be a 3D `torch.int` tensor whose dim 1 is a multiple of 128.
    The result is a new tensor with the same shape.
    """
    num_groups, mn, packed_sf_k = sf.shape
    assert sf.dtype == torch.int and mn % 128 == 0
    result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
                .transpose(2, 3)
                .reshape(num_groups, mn, packed_sf_k))
    return torch.empty_like(sf).copy_(result)


def _pack_ue4m3_to_int32(sf: torch.Tensor) -> torch.Tensor:
    """Pack every 4 consecutive bytes of the last dim into one int32.

    Byte `i` of each group of 4 lands in bits `[8 * i, 8 * i + 8)`. `sf` is
    reinterpreted as uint8, so its dtype must be one byte wide.
    """
    sf_u8 = sf.view(torch.uint8)
    assert sf_u8.shape[-1] % 4 == 0
    packed = (sf_u8[..., 0::4].to(torch.int32) |
              (sf_u8[..., 1::4].to(torch.int32) << 8) |
              (sf_u8[..., 2::4].to(torch.int32) << 16) |
              (sf_u8[..., 3::4].to(torch.int32) << 24))
    return packed.contiguous()


def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform `(weight, scale factor)` pairs of both layers for the mega MoE kernel.

    L1 weights and scale factors are gate/up interleaved; the scale factors of
    both layers are then transposed for UTCCP. L2 weights are returned as-is.
    """
    # L1: interleave gate/up, then transpose SF for UTCCP
    l1_interleaved = _interleave_l1_weights(l1_weights)
    l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1]))

    # L2: only transpose SF for UTCCP
    l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))
    return l1_weights, l2_weights


def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.

    NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X). The UTCCP
    layout packs 4 consecutive scale bytes into each int32, then applies
    the 4x32 transpose for TMA consumption.

    Input:  (num_experts, mn, K//16) float8_e4m3fn scales
    Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
    """
    num_groups, mn, sf_k = sf.shape
    assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
    assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"

    # Pack: every 4 uint8 → 1 int32, giving (num_groups, mn, sf_k//4)
    packed = _pack_ue4m3_to_int32(sf)

    # Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
    # by the 128-element alignment, not the scale vector size)
    return _transpose_sf_for_utccp(packed)


def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor],
    l1_weight_scale_2: Optional[torch.Tensor] = None,
    l2_weight_scale_2: Optional[torch.Tensor] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 expert weights for the mega_moe kernel.

    Per layer, the scale factors are processed as follows:
      1. the optional global scale (`*_weight_scale_2`) is folded into the
         block scales, clamped to [0, 448] and cast back to float8_e4m3fn;
      2. adjacent block16 scales are merged into block32 scales (max of pair);
      3. every 4 scale bytes are packed into one int32;
      4. (L1 only) gate/up rows are interleaved, exactly like the L1 weights;
      5. the result is made MN-major and passed to
         `deep_gemm.transform_sf_into_required_layout` with recipe `(1, 32)`.

    Returns `((l1_weight, l1_sf), (l2_weight, l2_sf))`, with the weights
    reinterpreted as int8 and the L1 weight gate/up interleaved.

    NOTE(review): merging block16 scales into block32 does not requantize the
    packed FP4 values here, so blocks whose own scale was the smaller of a pair
    are decoded with a larger scale — confirm this approximation is acceptable.
    """
    from deep_gemm import transform_sf_into_required_layout

    def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
        if scale_2 is None:
            return sf
        sf_f32 = sf.to(torch.float32)
        # A 1D scale is broadcast over the two trailing dims of `sf`
        if scale_2.dim() == 1:
            scale_2 = scale_2.view(-1, 1, 1)
        sf_f32 = sf_f32 * scale_2
        sf_f32 = sf_f32.clamp(0.0, 448.0)
        return sf_f32.to(torch.float8_e4m3fn)

    # Merge NVFP4 block16 scales → block32 for SM100 (scale_vec::2X)
    # B200 (SM100) doesn't support scale_vec::4X (block16) — requires SM103/SM120
    # Take the max of each pair of adjacent block16 scales for block32
    def merge_block16_to_block32(sf: torch.Tensor) -> torch.Tensor:
        # sf: (experts, mn, K//16) float8_e4m3fn
        # output: (experts, mn, K//32) float8_e4m3fn
        # An odd count would make the two slices below differ in length
        assert sf.shape[-1] % 2 == 0, f"SF K dim must be even to merge pairs, got {sf.shape[-1]}"
        sf_f32 = sf.to(torch.float32)
        # Take max of adjacent pairs (preserves magnitude, avoids underflow)
        sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2])
        return sf_merged.clamp(0.0, 448.0).to(torch.float8_e4m3fn)

    # Transpose to MN-major layout (stride(-2)=1) and make contiguous
    # transform_sf_into_required_layout expects MN-major input for TMA stride checks
    def to_mn_major(sf: torch.Tensor) -> torch.Tensor:
        return sf.transpose(-2, -1).contiguous().transpose(-2, -1)

    # Fold, merge, then pack UE4M3 (float8_e4m3fn) into int32: 4 bytes → 1 int32
    l1_sf_packed = _pack_ue4m3_to_int32(
        merge_block16_to_block32(fold_global_scale(l1_weights[1], l1_weight_scale_2)))
    l2_sf_packed = _pack_ue4m3_to_int32(
        merge_block16_to_block32(fold_global_scale(l2_weights[1], l2_weight_scale_2)))

    # L1: interleave gate/up for the weight AND its scale factors, so that every
    # scale-factor row still belongs to the weight row at the same index
    l1_weight, l1_sf_packed = _interleave_l1_weights((l1_weights[0], l1_sf_packed))

    # The last weight dim is doubled to obtain K
    # (presumably two FP4 values are stored per byte — TODO confirm)
    num_experts = l1_weights[0].shape[0]
    l1_n = l1_weights[0].shape[1]
    l1_k = l1_weights[0].shape[2] * 2
    l2_n = l2_weights[0].shape[1]
    l2_k = l2_weights[0].shape[2] * 2

    # Transform SF into the required layout using DeepGEMM's function
    # recipe (1, 32): matches the block32 scales produced by the merge above
    l1_sf_transformed = transform_sf_into_required_layout(
        to_mn_major(l1_sf_packed), l1_n, l1_k, (1, 32), num_experts)
    l2_sf_transformed = transform_sf_into_required_layout(
        to_mn_major(l2_sf_packed), l2_n, l2_k, (1, 32), num_experts)

    # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
    l1_out = (l1_weight.view(torch.int8), l1_sf_transformed)
    l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
    return l1_out, l2_out


def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Forward all arguments to `_C.fp8_fp4_mega_moe`.

    The buffer, peer buffer pointers, rank and MoE sizes are read from
    `sym_buffer`; every other argument is passed through unchanged.
    """
    _C.fp8_fp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer,
        sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )


def get_symm_buffer_for_nvfp4_mega_moe(
        group: 'dist.ProcessGroup',
        num_experts: int,
        num_max_tokens_per_rank: int, num_topk: int,
        hidden: int, intermediate_hidden: int,
        use_fp8_dispatch: bool = True,
        activation: str = 'swiglu') -> SymmBuffer:
    """Allocate a SymmBuffer sized for NVFP4 mega_moe (group_size=16)."""
    num_max_tokens_per_rank = align(num_max_tokens_per_rank,
                                    _C.get_token_alignment_for_nvfp4_mega_moe())

    # Bypass `SymmBuffer.__init__`, which sizes the buffer with the non-NVFP4
    # `_C` function, and fill in the same attributes by hand
    buf = SymmBuffer.__new__(SymmBuffer)
    buf.group = group
    buf.num_experts = num_experts
    buf.num_max_tokens_per_rank = num_max_tokens_per_rank
    buf.num_topk = num_topk
    buf.hidden = hidden
    buf.intermediate_hidden = intermediate_hidden

    # Use NVFP4-specific buffer size (2x SF due to group_size=16)
    num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
        group.size(), num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )
    buf._allocate(num_bytes, slice_input_buffers)
    return buf


def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: SymmBuffer,
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 32),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """NVFP4 mega MoE: forward all arguments to `_C.fp8_nvfp4_mega_moe`.

    Weight format: (E2M1 packed weights, int32 packed UE4M3 scales), as
    produced by `transform_nvfp4_weights_for_mega_moe`.
    Recipe: defaults to (1, 1, 32), which matches the block32 scales that
    `transform_nvfp4_weights_for_mega_moe` produces.

    NOTE(review): whether the kernel also accepts native block16 scales with
    recipe (1, 1, 16) (scale_vec::4X) cannot be told from this file — confirm
    against the kernel before passing a different recipe.
    """
    _C.fp8_nvfp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer,
        sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )