Files
DeepGEMM/deep_gemm/mega/__init__.py
Chenggang Zhao 7f2a703ed5 [Public release 26/04] Introducing Mega MoE, FP4 Indexer and other features/fixes (#304)
* Merge with private repo

* Update README

* Update README

* Update README

* Add PyTorch requirements

* Fix sync scopes for MQA logits (#256)

* Update README
2026-04-17 09:45:14 +08:00

129 lines
4.9 KiB
Python

# Postpone evaluation of annotations (PEP 563). The guarded import below may
# leave `dist` unbound; without this, the `dist.ProcessGroup` annotations in
# this module would raise `NameError` at import time right after the
# "Failed to load mega kernels" message, defeating the best-effort import.
from __future__ import annotations

import torch
from typing import Tuple, Optional

from ..utils.math import align

# noinspection PyBroadException
try:
    # noinspection PyProtectedMember
    import torch.distributed._symmetric_memory as symm_mem
    import torch.distributed as dist
except Exception as exception:
    # Best effort: keep the module importable; the mega kernels themselves
    # will fail when called because `symm_mem` / `dist` are not bound.
    print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')

from .. import _C
class SymmBuffer:
    """Symmetric-memory buffer holding the Mega MoE input views of one rank.

    Construction asks ``_C.get_symm_buffer_size_for_mega_moe`` for the byte
    size and a slicing callable, allocates the buffer with ``symm_mem.empty``,
    rendezvouses it over ``group``, zero-fills it, and slices it into views.

    Attributes set by ``__init__``:
        group, num_experts, num_max_tokens_per_rank, num_topk, hidden,
        intermediate_hidden: the constructor arguments, stored unchanged.
        buffer: the flat ``torch.int8`` CUDA allocation.
        handle: the object returned by ``symm_mem.rendezvous``.
        x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts,
        l2_acts_sf: the eight objects returned by ``slice_input_buffers``.
    """

    def __init__(self, group: dist.ProcessGroup,
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and slice the symmetric buffer.

        Args:
            group: process group used for the rendezvous and the barrier.
            num_experts: forwarded to the size query and stored.
            num_max_tokens_per_rank: forwarded to the size query and stored.
            num_topk: forwarded to the size query and stored.
            hidden: forwarded to the size query and stored.
            intermediate_hidden: forwarded to the size query and stored.
            use_fp8_dispatch: forwarded to the size query only (not stored).
            activation: forwarded to the size query only (not stored).

        NOTE(review): this calls ``group.barrier()``, so presumably every rank
        of ``group`` must construct its buffer together - confirm with callers.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Start from an all-zero buffer, and only continue once the group has
        # passed the barrier and the local CUDA work has completed
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the buffer and its views.

        All eight views are cleared, not only ``x`` / ``x_sf``: a view left
        behind would keep a reference alive after ``destroy()`` (presumably to
        the buffer's storage - TODO confirm with ``slice_input_buffers``).
        Calling this more than once is harmless.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None
def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Build a ``SymmBuffer`` whose token capacity is aligned to block m.

    The block m value comes from ``_C.get_block_m_for_mega_moe`` and the
    requested ``num_max_tokens_per_rank`` is passed through ``align`` with it
    (presumably rounding up to a multiple of block m - confirm in
    ``utils.math``). Every other argument is forwarded to ``SymmBuffer``
    unchanged.

    Returns:
        The newly constructed ``SymmBuffer``.
    """
    # Token count must be aligned to block m
    block_m = _C.get_block_m_for_mega_moe(
        group.size(), num_experts, num_max_tokens_per_rank, num_topk
    )
    aligned_num_tokens = align(num_max_tokens_per_rank, block_m)

    return SymmBuffer(
        group, num_experts,
        aligned_num_tokens, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
def interleave(t, gran: int = 8) -> torch.Tensor:
g, n, *rest = t.shape
half = n // 2
gate = t[:, :half].reshape(g, half // gran, gran, *rest)
up = t[:, half:].reshape(g, half // gran, gran, *rest)
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
return interleave(l1_weights[0]), interleave(l1_weights[1])
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
num_groups, mn, packed_sf_k = sf.shape
assert sf.dtype == torch.int and mn % 128 == 0
result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(sf).copy_(result)
def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Convert L1/L2 weight pairs into the layout used by the Mega MoE kernel.

    Args:
        l1_weights: pair whose two tensors are both gate/up interleaved; the
            second one (the SF, per the comments here) is then also transposed
            for UTCCP.
        l2_weights: pair whose first tensor is passed through untouched and
            whose second tensor is transposed for UTCCP.

    Returns:
        ``(l1_weights, l2_weights)`` as two new pairs.
    """
    # L1: interleave gate/up, then transpose SF for UTCCP
    l1_data, l1_sf = _interleave_l1_weights(l1_weights)
    l1_transformed = (l1_data, _transpose_sf_for_utccp(l1_sf))

    # L2: only transpose SF for UTCCP
    l2_transformed = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))

    return l1_transformed, l2_transformed
def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Launch the Mega MoE kernel through ``_C.fp8_fp4_mega_moe``.

    This is a thin wrapper that unpacks what the extension needs from
    ``sym_buffer`` and forwards everything else unchanged.

    Args:
        y: forwarded as the first argument (presumably the output tensor -
            confirm against the extension).
        l1_weights: L1 weight pair, forwarded as is.
        l2_weights: L2 weight pair, forwarded as is.
        sym_buffer: supplies the buffer, the rendezvous handle's
            ``buffer_ptrs``, the rank within its group, and the stored
            ``num_max_tokens_per_rank`` / ``num_experts`` / ``num_topk``.
        recipe: forwarded as is.
        activation: forwarded as is.
        activation_clamp: forwarded as is; ``None`` by default.
        fast_math: forwarded as is.

    Returns:
        None; the result of the extension call is not returned.
    """
    # Values taken from the symmetric buffer
    buffer = sym_buffer.buffer
    buffer_ptrs = sym_buffer.handle.buffer_ptrs
    rank = sym_buffer.group.rank()
    num_max_tokens_per_rank = sym_buffer.num_max_tokens_per_rank
    num_experts = sym_buffer.num_experts
    num_topk = sym_buffer.num_topk

    _C.fp8_fp4_mega_moe(
        y,
        l1_weights, l2_weights,
        buffer,
        buffer_ptrs, rank,
        num_max_tokens_per_rank,
        num_experts, num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )