import torch
from typing import Tuple, Optional

from ..utils.math import align

# noinspection PyBroadException
try:
    # noinspection PyProtectedMember
    import torch.distributed._symmetric_memory as symm_mem
    import torch.distributed as dist
except Exception as exception:
    print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
    # Keep the module importable: the ``dist.ProcessGroup`` annotations below
    # are quoted so they are not evaluated at definition time, and buffer
    # allocation checks for ``symm_mem is None`` and raises a clear error.
    symm_mem = None
    dist = None

from .. import _C


# Attribute names of the input views carved out of the symmetric buffer, in
# the order returned by the ``slice_input_buffers`` callback from ``_C``.
_INPUT_VIEW_NAMES = ('x', 'x_sf', 'topk_idx', 'topk_weights',
                     'l1_acts', 'l1_acts_sf', 'l2_acts', 'l2_acts_sf')


class _SymmBufferBase:
    """Shared implementation of the mega-MoE symmetric workspace.

    Allocates one ``int8`` symmetric-memory buffer on CUDA, sized by a ``_C``
    layout query. The buffer is rendezvoused across ``group`` and zero-filled,
    and its input regions are exposed as attributes (``x``, ``x_sf``,
    ``topk_idx``, ``topk_weights``, ``l1_acts``, ``l1_acts_sf``, ``l2_acts``,
    ``l2_acts_sf``). Each attribute is the view produced by the
    ``slice_input_buffers`` callback that ``_C`` returns.

    Subclasses only select which ``_C`` layout query to use, by overriding
    ``_query_layout``.
    """

    @staticmethod
    def _query_layout(num_ranks, num_experts, num_max_tokens_per_rank, num_topk,
                      hidden, intermediate_hidden, use_fp8_dispatch, activation):
        """Return ``(num_bytes, slice_input_buffers)`` for this buffer flavour."""
        raise NotImplementedError

    def __init__(self, group: 'dist.ProcessGroup',
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate and rendezvous the symmetric buffer for ``group``.

        Raises:
            RuntimeError: if ``torch.distributed._symmetric_memory`` could not
                be imported (the module-level import fell back to ``None``).
        """
        if symm_mem is None:
            raise RuntimeError('torch.distributed._symmetric_memory is unavailable; '
                               'cannot allocate a mega MoE symmetric buffer')

        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = self._query_layout(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Zero-fill, then synchronize across ranks and with the device before
        # the views are handed out.
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views (explicit unpacking keeps the arity check)
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self) -> None:
        """Drop every reference this object holds to the symmetric allocation.

        All input views share storage with ``self.buffer``. Clearing only some
        of them would keep the allocation alive, so every view is reset.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        for name in _INPUT_VIEW_NAMES:
            setattr(self, name, None)


class SymmBuffer(_SymmBufferBase):
    """Symmetric buffer for the ``fp8_fp4_mega_moe`` kernel.

    The layout comes from ``_C.get_symm_buffer_size_for_mega_moe``.
    """

    @staticmethod
    def _query_layout(*args):
        return _C.get_symm_buffer_size_for_mega_moe(*args)


class NVFP4SymmBuffer(_SymmBufferBase):
    """Symmetric buffer for the NVFP4 mega_moe kernel.

    Same structure as SymmBuffer but with NVFP4-specific SF layout:
    group_size=16 means 2x the SF data compared to MXFP4 (group_size=32).
    x_sf is packed int32 (4 UE4M3 per int32) with shape (M, K//64).
    The layout comes from ``_C.get_symm_buffer_size_for_nvfp4_mega_moe``.
    """

    @staticmethod
    def _query_layout(*args):
        return _C.get_symm_buffer_size_for_nvfp4_mega_moe(*args)


def get_symm_buffer_for_mega_moe(group: 'dist.ProcessGroup',
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Create a ``SymmBuffer`` after aligning the per-rank token capacity.

    ``num_max_tokens_per_rank`` is passed through ``align`` with
    ``_C.get_token_alignment_for_mega_moe()``.
    """
    # Token count must be aligned to block sizes
    num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe())
    return SymmBuffer(
        group, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )


def get_symm_buffer_for_nvfp4_mega_moe(group: 'dist.ProcessGroup',
                                       num_experts: int,
                                       num_max_tokens_per_rank: int, num_topk: int,
                                       hidden: int, intermediate_hidden: int,
                                       use_fp8_dispatch: bool = True,
                                       activation: str = 'swiglu') -> NVFP4SymmBuffer:
    """Create an ``NVFP4SymmBuffer`` after aligning the per-rank token capacity.

    ``num_max_tokens_per_rank`` is passed through ``align`` with
    ``_C.get_token_alignment_for_nvfp4_mega_moe()``.
    """
    # Token count must be aligned to block sizes
    num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_nvfp4_mega_moe())
    return NVFP4SymmBuffer(
        group, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )


def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Interleave the gate/up halves of L1 weights and their scale factors.

    The second dimension is rearranged from ``[gate | up]`` to
    ``[gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...]``.
    Weights are ``(groups, n, ...)``; scale factors are
    ``(groups, mn, packed_sf_k)`` in MN-major layout.
    """
    def interleave(t, gran: int = 8) -> torch.Tensor:
        # Split dim 1 into gate/up halves, chunk each into ``gran`` rows and
        # alternate the chunks; the result is a fresh contiguous tensor.
        g, n, *rest = t.shape
        half = n // 2
        gate = t[:, :half].reshape(g, half // gran, gran, *rest)
        up = t[:, half:].reshape(g, half // gran, gran, *rest)
        return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))

    def interleave_sf_mn_major(t, gran: int = 8) -> torch.Tensor:
        """Interleave SF along mn while returning an MN-major (stride(-2) == 1) tensor.

        Input/Output shape: (num_groups, mn, packed_sf_k).
        Interleaves the mn dimension: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...]
        """
        # Transpose to K-major C-contiguous so the mn split is a plain reshape
        t_k = t.transpose(-2, -1).contiguous()  # (groups, packed_sf_k, mn)
        g, k, mn = t_k.shape
        half = mn // 2
        gate = t_k[:, :, :half].reshape(g, k, half // gran, gran)
        up = t_k[:, :, half:].reshape(g, k, half // gran, gran)
        interleaved_k = torch.empty(g, k, mn, dtype=t.dtype, device=t.device)
        interleaved_k.copy_(torch.stack([gate, up], dim=3).reshape(g, k, mn))
        # Single transpose back to MN-major: (g, mn, k) with stride(-2)=1
        return interleaved_k.transpose(-2, -1)

    return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])


def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Reorder int32 scale-factor rows within each 128-row tile for UTCCP.

    Viewing each 128-row tile as a ``4 x 32`` grid of rows, the grid is
    transposed to ``32 x 4``. The result is written into a tensor created by
    ``torch.empty_like(sf)``. Requires ``sf.dtype == torch.int`` and
    ``mn % 128 == 0``.
    """
    num_groups, mn, packed_sf_k = sf.shape
    assert sf.dtype == torch.int and mn % 128 == 0
    result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
              .transpose(2, 3)
              .reshape(num_groups, mn, packed_sf_k))
    return torch.empty_like(sf).copy_(result)


def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Prepare ``(weight, packed_sf)`` pairs for ``fp8_fp4_mega_moe``.

    L1: gate/up rows are interleaved, then the SF rows are reordered for UTCCP.
    L2: only the SF rows are reordered for UTCCP.
    """
    # L1: interleave gate/up, then transpose SF for UTCCP
    l1_interleaved = _interleave_l1_weights(l1_weights)
    l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1]))

    # L2: only transpose SF for UTCCP
    l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))
    return l1_weights, l2_weights


def _fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
    """Fold ``weight_scale_2`` into UE4M3 block scales and re-quantize to float8_e4m3fn.

    Returns ``sf`` untouched when ``scale_2`` is None. A 1-D ``scale_2`` is
    reshaped to ``(-1, 1, 1)``, i.e. one scale per leading (expert) index;
    other shapes are broadcast as-is.
    """
    if scale_2 is None:
        return sf
    # Block scales are stored as float8_e4m3fn (UE4M3 per the NVFP4 spec), so
    # a plain dtype cast decodes them; they are not UE8M0 exponents.
    sf_f32 = sf.to(torch.float32)
    if scale_2.dim() == 1:
        scale_2 = scale_2.view(-1, 1, 1)
    # 448 is the largest finite float8_e4m3fn value.
    # NOTE(review): products below E4M3's smallest subnormal (2**-9) lose
    # precision or flush to zero on this cast; confirm the global scales keep
    # the folded values in range.
    return (sf_f32 * scale_2).clamp(0.0, 448.0).to(torch.float8_e4m3fn)


def _pack_ue4m3_to_int32(sf: torch.Tensor) -> torch.Tensor:
    """Pack every 4 consecutive UE4M3 bytes of the last dim into one int32.

    Byte ``i`` of each group of four lands in bits ``8*i .. 8*i+7``. The last
    dim must be divisible by 4. The result is contiguous.
    """
    sf_u8 = sf.view(torch.uint8)
    assert sf_u8.shape[-1] % 4 == 0
    packed = (sf_u8[..., 0::4].to(torch.int32)
              | (sf_u8[..., 1::4].to(torch.int32) << 8)
              | (sf_u8[..., 2::4].to(torch.int32) << 16)
              | (sf_u8[..., 3::4].to(torch.int32) << 24))
    return packed.contiguous()


def _to_mn_major(t: torch.Tensor) -> torch.Tensor:
    """Return ``t`` with the same shape but MN-major strides (stride(-2) == 1)."""
    return t.transpose(-2, -1).contiguous().transpose(-2, -1)


def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor],
    l1_weight_scale_2: Optional[torch.Tensor] = None,
    l2_weight_scale_2: Optional[torch.Tensor] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 expert weights for the mega_moe kernel.

    NVFP4 weights come as (weight, scale) where:
      - weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
      - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)

    Steps:
      1. If weight_scale_2 (float32 global scale) is given, fold it into the
         block scales: effective_scale = block_scale * global_scale,
         re-quantized to UE4M3.
      2. Pack the scales into int32 (4 UE4M3 per int32) with MN-major strides.
      3. Lay them out with DeepGEMM's ``transform_sf_into_required_layout``
         using recipe (1, 16).
      4. Interleave L1 gate/up rows.

    Returns ``((l1_weight, l1_sf), (l2_weight, l2_sf))``, with the weights
    reinterpreted as int8.
    """
    from deep_gemm import transform_sf_into_required_layout

    # Fold global scales into block scales
    l1_sf = _fold_global_scale(l1_weights[1], l1_weight_scale_2)
    l2_sf = _fold_global_scale(l2_weights[1], l2_weight_scale_2)

    num_experts = l1_weights[0].shape[0]
    # Weights store two E2M1 values per byte, so logical K is twice the width
    l1_n, l1_k = l1_weights[0].shape[1], l1_weights[0].shape[2] * 2
    l2_n, l2_k = l2_weights[0].shape[1], l2_weights[0].shape[2] * 2

    # Pack UE4M3 bytes into int32 and give them MN-major strides, which
    # transform_sf_into_required_layout expects for its TMA stride checks
    l1_sf_mn = _to_mn_major(_pack_ue4m3_to_int32(l1_sf))
    l2_sf_mn = _to_mn_major(_pack_ue4m3_to_int32(l2_sf))

    # recipe (1, 16): gran_mn=1, gran_k=16 (NVFP4 native block16)
    l1_sf_transformed = transform_sf_into_required_layout(
        l1_sf_mn, l1_n, l1_k, (1, 16), num_experts)
    l2_sf_transformed = transform_sf_into_required_layout(
        l2_sf_mn, l2_n, l2_k, (1, 16), num_experts)

    # NOTE(review): unlike transform_weights_for_mega_moe, this path does not
    # apply _transpose_sf_for_utccp to either SF tensor. Confirm that
    # transform_sf_into_required_layout or the NVFP4 kernel already accounts
    # for the UTCCP row order.

    # L1: interleave gate/up
    l1_w_interleaved, l1_sf_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_transformed))

    # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
    l1_out = (l1_w_interleaved.view(torch.int8), l1_sf_interleaved)
    l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
    return l1_out, l2_out


def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Launch ``_C.fp8_fp4_mega_moe`` using the buffer, peer pointers and sizes of ``sym_buffer``.

    ``y`` is presumably the output tensor the kernel writes into (TODO
    confirm). Weights are expected in the layout produced by
    ``transform_weights_for_mega_moe``.
    """
    _C.fp8_fp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer, sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )


def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: NVFP4SymmBuffer,
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 16),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X with UE4M3 block scales (group_size=16).

    Both activations AND weights are E2M1 packed (FP4 x FP4).
    Weight format: (int8-viewed E2M1 packed, int32 packed UE4M3 scales), as
    produced by ``transform_nvfp4_weights_for_mega_moe``.
    Activation format: E2M1 packed uint8 + UE4M3 scales (computed by staging kernel).
    Recipe: (1, 1, 16) - kGranK=16 for NVFP4 group_size=16.
    ``sym_buffer`` should come from ``get_symm_buffer_for_nvfp4_mega_moe``.
    """
    # Unpacking normalizes each weight pair to a 2-tuple (and checks arity)
    l1_w, l1_w_sf = l1_weights
    l2_w, l2_w_sf = l2_weights
    _C.fp8_nvfp4_mega_moe(
        y,
        (l1_w, l1_w_sf), (l2_w, l2_w_sf),
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer, sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )