import torch
from typing import Tuple, Optional
from ..utils.math import align

# noinspection PyBroadException
try:
    # noinspection PyProtectedMember
    import torch.distributed._symmetric_memory as symm_mem
    import torch.distributed as dist
except Exception as exception:
    print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')

from .. import _C


class SymmBuffer:
    """A symmetric-memory allocation plus the input views sliced out of it.

    The byte size and the slicing function both come from
    `_C.get_symm_buffer_size_for_mega_moe`; this class only allocates,
    rendezvouses, zeroes and keeps the resulting tensors.
    """

    # Names of the views returned by `slice_input_buffers`, in return order
    _VIEW_NAMES = ('x', 'x_sf', 'topk_idx', 'topk_weights',
                   'l1_acts', 'l1_acts_sf', 'l2_acts', 'l2_acts_sf')

    # NOTES: `dist.ProcessGroup` is quoted so that the class can still be defined
    # when the guarded import above failed and `dist` is unbound
    def __init__(self, group: 'dist.ProcessGroup',
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu') -> None:
        """Allocate and rendezvous the buffer on `group`, then slice the views.

        `use_fp8_dispatch` and `activation` are only forwarded to the size
        query; they are not stored on the instance.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Zero the buffer, then wait on the group barrier and the device
        # before any view is handed out
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self) -> None:
        """Drop every reference this object holds to the buffer.

        All eight sliced views are cleared, not only `x` / `x_sf`: they are
        produced from `self.buffer`, so a surviving view would presumably
        keep the underlying allocation alive.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        for name in self._VIEW_NAMES:
            setattr(self, name, None)


def get_symm_buffer_for_mega_moe(group: 'dist.ProcessGroup',
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Build a `SymmBuffer` after aligning the per-rank token capacity.

    `num_max_tokens_per_rank` is passed through `align` with the value of
    `_C.get_token_alignment_for_mega_moe()` (presumably rounding it up to a
    multiple of that alignment - confirm against `utils.math.align`); all
    other arguments are forwarded unchanged.
    """
    # Token count must be aligned to block sizes
    num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe())

    return SymmBuffer(
        group, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )


def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Interleave the two halves of dim 1 in chunks of 8 rows, for both tensors.

    Each tensor is treated as `(g, n, *rest)`; rows `[0, n / 2)` are the gate
    half and rows `[n / 2, n)` the up half.
    """
    # [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
    def interleave(t, gran: int = 8) -> torch.Tensor:
        g, n, *rest = t.shape
        half = n // 2
        # Split each half into chunks of `gran` rows
        gate = t[:, :half].reshape(g, half // gran, gran, *rest)
        up = t[:, half:].reshape(g, half // gran, gran, *rest)
        # Stacking on dim 2 places the matching gate/up chunks next to each other;
        # the result is copied into a fresh tensor allocated like `t`
        return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))

    return interleave(l1_weights[0]), interleave(l1_weights[1])


def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Reorder the rows of `sf` inside every block of 128 rows.

    `sf` must be a 3-D `torch.int` tensor `(num_groups, mn, packed_sf_k)` with
    `mn` divisible by 128. Each 128-row block is viewed as `(4, 32)` and
    transposed to `(32, 4)`, so the row at offset `a * 32 + b` moves to offset
    `b * 4 + a`. The result is a new tensor allocated like `sf`.
    """
    num_groups, mn, packed_sf_k = sf.shape
    assert sf.dtype == torch.int and mn % 128 == 0
    result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
                .transpose(2, 3)
                .reshape(num_groups, mn, packed_sf_k))
    return torch.empty_like(sf).copy_(result)


def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Rearrange both weight pairs; element 1 of each pair is the SF tensor.

    The L2 weight tensor (element 0) is returned as the same object, every
    other returned tensor is newly allocated.
    """
    # L1: interleave gate/up, then transpose SF for UTCCP
    l1_interleaved = _interleave_l1_weights(l1_weights)
    l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1]))

    # L2: only transpose SF for UTCCP
    l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))
    return l1_weights, l2_weights


def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True) -> None:
    """Forward the arguments and the `sym_buffer` state to `_C.fp8_fp4_mega_moe`.

    From `sym_buffer` the call reads `buffer`, `handle.buffer_ptrs`,
    `group.rank()`, `num_max_tokens_per_rank`, `num_experts` and `num_topk`,
    so it must not be used after `SymmBuffer.destroy()`. Nothing is returned;
    `y` is presumably the output tensor written by the kernel (TODO confirm).
    """
    _C.fp8_fp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer,
        sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )