Files
DeepGEMM/deep_gemm/mega/__init__.py
biondizzle 03b8c99ee1 fix: use mxf8f6f4 (UE8M0) on SM100 — mxf4nvf4 requires SM103+
B200 (SM100) does NOT support kind::mxf4nvf4 at all (neither 2X nor 4X).
Only mxf8f6f4.block_scale with UE8M0 scales is available on SM100.

Strategy: keep NVFP4 E2M1 weights, convert UE4M3 block scales → UE8M0
in the weight transformation. This is a scale format adaptation for
hardware compatibility, not a format conversion.

Changes:
- Kernel: back to mxf8f6f4 instruction + float_ue8m0_t descriptor
- L1 epilogue: back to UE8M0 (>> 23) activation scales
- Python: merge block16→block32, convert UE4M3→float32→UE8M0
- Packing: uint8 (UE8M0) → int32, same as MXFP4
2026-05-11 09:28:45 +00:00

316 lines
13 KiB
Python

import torch
from typing import Tuple, Optional
from ..utils.math import align
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed as dist
except Exception as exception:
print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
from .. import _C
class SymmBuffer:
    """Symmetric-memory workspace used by the mega MoE entry points.

    The constructor asks ``_C.get_symm_buffer_size_for_mega_moe`` for the
    number of bytes required and for a slicing callback, allocates an int8
    CUDA tensor of that size through ``symm_mem.empty``, performs the
    rendezvous on ``group`` and finally exposes the slices returned by the
    callback as attributes (``x``, ``x_sf``, ``topk_idx``, ``topk_weights``,
    ``l1_acts``, ``l1_acts_sf``, ``l2_acts``, ``l2_acts_sf``).
    """

    # Attributes that hold slices returned by ``slice_input_buffers``.
    _VIEW_ATTRS = ('x', 'x_sf',
                   'topk_idx', 'topk_weights',
                   'l1_acts', 'l1_acts_sf',
                   'l2_acts', 'l2_acts_sf')

    # The ``group`` annotation is quoted so that defining this class does not
    # evaluate ``dist.ProcessGroup``; ``dist`` is only bound when the guarded
    # import at the top of the module succeeded.
    def __init__(self, group: "dist.ProcessGroup",
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate the buffer, rendezvous with the group and create the views.

        Args:
            group: process group used for the rendezvous and the barrier.
            num_experts: forwarded to the size query and stored.
            num_max_tokens_per_rank: forwarded to the size query and stored.
            num_topk: forwarded to the size query and stored.
            hidden: forwarded to the size query and stored.
            intermediate_hidden: forwarded to the size query and stored.
            use_fp8_dispatch: forwarded to the size query only.
            activation: forwarded to the size query only.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Zero the buffer, then barrier and synchronize before it is used
        # (presumably so that no rank starts from a partially initialised
        # peer buffer — TODO confirm).
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the symmetric buffer.

        Besides the handle, the buffer and the group, all views created from
        the buffer are reset to ``None``; a view left behind would keep
        referencing the buffer after ``destroy()``. The scalar configuration
        attributes are left untouched. Calling this twice is harmless.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        for view_name in self._VIEW_ATTRS:
            setattr(self, view_name, None)
def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Create a ``SymmBuffer`` with an aligned per-rank token capacity.

    ``num_max_tokens_per_rank`` is passed through ``align`` together with the
    value reported by ``_C.get_token_alignment_for_mega_moe()`` (presumably
    rounding it up to a multiple of that alignment — verify against
    ``utils.math.align``). Every other argument is handed to ``SymmBuffer``
    unchanged.
    """
    # Token count must be aligned to block sizes
    padded_tokens = align(num_max_tokens_per_rank,
                          _C.get_token_alignment_for_mega_moe())
    return SymmBuffer(
        group=group,
        num_experts=num_experts,
        num_max_tokens_per_rank=padded_tokens,
        num_topk=num_topk,
        hidden=hidden,
        intermediate_hidden=intermediate_hidden,
        use_fp8_dispatch=use_fp8_dispatch,
        activation=activation,
    )
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
def interleave(t, gran: int = 8) -> torch.Tensor:
g, n, *rest = t.shape
half = n // 2
gate = t[:, :half].reshape(g, half // gran, gran, *rest)
up = t[:, half:].reshape(g, half // gran, gran, *rest)
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
return interleave(l1_weights[0]), interleave(l1_weights[1])
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
num_groups, mn, packed_sf_k = sf.shape
assert sf.dtype == torch.int and mn % 128 == 0
result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(sf).copy_(result)
def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Prepare ``(weight, scale_factor)`` pairs of both layers for mega MoE.

    L1: weight and scale factors go through ``_interleave_l1_weights``; the
    interleaved scale factors are then passed to ``_transpose_sf_for_utccp``.
    L2: the weight is returned as given; only its scale factors are passed to
    ``_transpose_sf_for_utccp``.

    Returns:
        ``((l1_weight, l1_sf), (l2_weight, l2_sf))``.
    """
    # L1: interleave gate/up, then transpose SF for UTCCP
    l1_data, l1_sf = _interleave_l1_weights(l1_weights)
    l1_pair = (l1_data, _transpose_sf_for_utccp(l1_sf))

    # L2: only transpose SF for UTCCP
    l2_pair = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))

    return l1_pair, l2_pair
def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
"""Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.
NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X).
The UTCCP layout packs 4 consecutive scale bytes into each int32,
then applies the 4x32 transpose for TMA consumption.
Input: (num_experts, mn, K//16) float8_e4m3fn scales
Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
"""
num_groups, mn, sf_k = sf.shape
assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"
# View as uint8 and pack 4 consecutive bytes into int32
sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k)
# Pack: every 4 uint8 → 1 int32
packed = (sf_uint8[..., 0::4].to(torch.int32) |
(sf_uint8[..., 1::4].to(torch.int32) << 8) |
(sf_uint8[..., 2::4].to(torch.int32) << 16) |
(sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4)
# Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
# by the 128-element alignment, not the scale vector size)
packed_sf_k = sf_k // 4
result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(packed).copy_(result)
def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor],
    l1_weight_scale_2: Optional[torch.Tensor] = None,
    l2_weight_scale_2: Optional[torch.Tensor] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 expert weights for the mega_moe kernel.

    Steps, for both layers:

    1. Optionally multiply the block16 scales by the global scale
       (``*_weight_scale_2``) in float32.
    2. Merge adjacent pairs of block16 scales into one block32 scale (max of
       the pair) and keep only the float32 exponent byte, i.e. a UE8M0 value.
    3. Pack 4 UE8M0 bytes into one int32.
    4. L1 only: interleave gate/up rows of the weight AND of its packed
       scales with ``_interleave_l1_weights`` so both stay row-aligned.
    5. Make the packed scales MN-major and pass them to
       ``deep_gemm.transform_sf_into_required_layout`` with recipe ``(1, 32)``.

    NOTE(review): steps 1-2 are lossy. The packed FP4 payload is not
    re-quantized here, the smaller scale of each pair is replaced by the
    larger one, and the mantissa of the scale is discarded (rounded down to
    a power of two). Confirm this accuracy trade-off is acceptable.

    Args:
        l1_weights: ``(packed weight, block16 scales)`` of layer 1. The weight
            is assumed to be ``(experts, n, k // 2)`` with two FP4 values per
            byte and the scales ``(experts, n, k // 16)`` — TODO confirm.
        l2_weights: same pair for layer 2.
        l1_weight_scale_2: optional global scale for layer 1; a 1-D tensor is
            reshaped to ``(-1, 1, 1)`` before the multiplication.
        l2_weight_scale_2: optional global scale for layer 2.

    Returns:
        ``((l1_weight_int8, l1_sf), (l2_weight_int8, l2_sf))`` where the
        weights are ``int8`` views and the scale factors are the outputs of
        ``transform_sf_into_required_layout``.

    Raises:
        ValueError: if the number of block16 scales along K is odd.
        AssertionError: if the number of block32 scales is not a multiple of 4.
    """
    import logging
    from deep_gemm import transform_sf_into_required_layout

    logger = logging.getLogger(__name__)

    def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
        # Stay in float32: the target format is UE8M0, so a round trip through
        # float8_e4m3fn (with its 448 clamp and early underflow) would only
        # lose range and precision.
        sf_f32 = sf.to(torch.float32)
        if scale_2 is None:
            return sf_f32
        if scale_2.dim() == 1:
            scale_2 = scale_2.view(-1, 1, 1)
        # Cast back to float32 in case the product was promoted, and keep the
        # scales non-negative so the sign bit never leaks into the exponent.
        return (sf_f32 * scale_2).to(torch.float32).clamp(min=0.0)

    def merge_block16_to_block32(sf_f32: torch.Tensor) -> torch.Tensor:
        # sf_f32: (experts, mn, K // 16) float32
        # output: (experts, mn, K // 32) uint8 (UE8M0)
        if sf_f32.shape[-1] % 2 != 0:
            raise ValueError(
                f"Number of block16 scales along K must be even, got {sf_f32.shape[-1]}")
        # Take max of adjacent pairs
        merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2])
        # UE8M0 byte == biased float32 exponent. Reinterpret as int32 (uint32
        # has limited operator support) and mask to the 8 exponent bits.
        exponent = (merged.view(torch.int32) >> 23) & 0xFF
        return exponent.to(torch.uint8)

    # Pack UE8M0 (uint8) block scales into int32, 4 uint8 -> 1 int32
    # (byte 0 in the lowest 8 bits).
    def pack_uint8_to_int32(sf: torch.Tensor) -> torch.Tensor:
        assert sf.dtype == torch.uint8
        assert sf.shape[-1] % 4 == 0
        packed = (sf[..., 0::4].to(torch.int32) |
                  (sf[..., 1::4].to(torch.int32) << 8) |
                  (sf[..., 2::4].to(torch.int32) << 16) |
                  (sf[..., 3::4].to(torch.int32) << 24))
        return packed.contiguous()

    def to_mn_major(sf: torch.Tensor) -> torch.Tensor:
        # Same logical shape, but laid out so that stride(-2) == 1.
        return sf.transpose(-2, -1).contiguous().transpose(-2, -1)

    l1_sf_32 = merge_block16_to_block32(fold_global_scale(l1_weights[1], l1_weight_scale_2))
    l2_sf_32 = merge_block16_to_block32(fold_global_scale(l2_weights[1], l2_weight_scale_2))

    num_experts = l1_weights[0].shape[0]
    l1_n = l1_weights[0].shape[1]
    l1_k = l1_weights[0].shape[2] * 2  # two FP4 values per stored byte
    l2_n = l2_weights[0].shape[1]
    l2_k = l2_weights[0].shape[2] * 2

    l1_sf_packed = pack_uint8_to_int32(l1_sf_32)
    l2_sf_packed = pack_uint8_to_int32(l2_sf_32)

    logger.debug("[NVFP4-MoE] l1_sf_32: shape=%s, l1_sf_packed: shape=%s",
                 l1_sf_32.shape, l1_sf_packed.shape)
    logger.debug("[NVFP4-MoE] l2_sf_32: shape=%s, l2_sf_packed: shape=%s",
                 l2_sf_32.shape, l2_sf_packed.shape)
    logger.debug("[NVFP4-MoE] l1_n=%s l1_k=%s l2_n=%s l2_k=%s", l1_n, l1_k, l2_n, l2_k)

    # L1: interleave gate/up for the weight and its scales together. The
    # scales are per row, so they must follow the same row permutation as the
    # weight before they are laid out for the kernel.
    l1_weight_interleaved, l1_sf_interleaved = _interleave_l1_weights(
        (l1_weights[0], l1_sf_packed))

    # transform_sf_into_required_layout is fed MN-major input here;
    # recipe (1, 32) is presumably (gran_mn, gran_k) — TODO confirm.
    l1_sf_transformed = transform_sf_into_required_layout(
        to_mn_major(l1_sf_interleaved), l1_n, l1_k, (1, 32), num_experts)
    l2_sf_transformed = transform_sf_into_required_layout(
        to_mn_major(l2_sf_packed), l2_n, l2_k, (1, 32), num_experts)

    # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
    l1_out = (l1_weight_interleaved.view(torch.int8), l1_sf_transformed)
    l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
    return l1_out, l2_out
def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Call ``_C.fp8_fp4_mega_moe`` with the state held by ``sym_buffer``.

    The caller's arguments are forwarded unchanged. From ``sym_buffer`` the
    wrapper supplies the buffer tensor, ``handle.buffer_ptrs``, the rank of
    ``group`` and the stored ``num_max_tokens_per_rank`` / ``num_experts`` /
    ``num_topk``. Nothing is returned (``y`` is presumably written by the
    kernel — TODO confirm).
    """
    kernel = _C.fp8_fp4_mega_moe

    # Pieces taken from the symmetric buffer object.
    symm_tensor = sym_buffer.buffer
    peer_ptrs = sym_buffer.handle.buffer_ptrs
    rank = sym_buffer.group.rank()
    max_tokens = sym_buffer.num_max_tokens_per_rank
    num_experts = sym_buffer.num_experts
    num_topk = sym_buffer.num_topk

    kernel(y,
           l1_weights, l2_weights,
           cumulative_local_expert_recv_stats,
           symm_tensor,
           peer_ptrs, rank,
           max_tokens,
           num_experts, num_topk,
           recipe,
           activation, activation_clamp,
           fast_math)
def get_symm_buffer_for_nvfp4_mega_moe(
        group: "dist.ProcessGroup",
        num_experts: int,
        num_max_tokens_per_rank: int, num_topk: int,
        hidden: int, intermediate_hidden: int,
        use_fp8_dispatch: bool = True,
        activation: str = 'swiglu') -> SymmBuffer:
    """Build a ``SymmBuffer`` whose size comes from the NVFP4-specific query.

    ``SymmBuffer.__init__`` is bypassed (``__new__`` only) so that the buffer
    can be sized with ``_C.get_symm_buffer_size_for_nvfp4_mega_moe`` and the
    token count aligned with ``_C.get_token_alignment_for_nvfp4_mega_moe``.
    The remaining steps mirror the constructor: allocate, rendezvous, zero,
    barrier, synchronize, then attach the sliced views.

    NOTE(review): the original notes describe this sizing as "2x SF due to
    group_size=16" — confirm against the C++ side.
    """
    from .. import _C

    padded_tokens = align(num_max_tokens_per_rank,
                          _C.get_token_alignment_for_nvfp4_mega_moe())

    # NVFP4-specific buffer size and slicing callback
    num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
        group.size(), num_experts,
        padded_tokens, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation)

    # Populate the plain configuration attributes without running __init__.
    buf = SymmBuffer.__new__(SymmBuffer)
    config = (
        ('group', group),
        ('num_experts', num_experts),
        ('num_max_tokens_per_rank', padded_tokens),
        ('num_topk', num_topk),
        ('hidden', hidden),
        ('intermediate_hidden', intermediate_hidden),
    )
    for attr_name, attr_value in config:
        setattr(buf, attr_name, attr_value)

    import torch.distributed._symmetric_memory as symm_mem
    import torch.distributed as dist  # not referenced in this function body

    buf.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
    buf.handle = symm_mem.rendezvous(buf.buffer, group=group)
    buf.buffer.zero_()
    buf.group.barrier()
    torch.cuda.synchronize()

    (buf.x, buf.x_sf,
     buf.topk_idx, buf.topk_weights,
     buf.l1_acts, buf.l1_acts_sf,
     buf.l2_acts, buf.l2_acts_sf) = slice_input_buffers(buf.buffer)
    return buf
def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: SymmBuffer,
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 32),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """Call ``_C.fp8_nvfp4_mega_moe`` with the state held by ``sym_buffer``.

    The caller's arguments are forwarded unchanged. From ``sym_buffer`` the
    wrapper supplies the buffer tensor, ``handle.buffer_ptrs``, the rank of
    ``group`` and the stored ``num_max_tokens_per_rank`` / ``num_experts`` /
    ``num_topk``. The default recipe is ``(1, 1, 32)``. Nothing is returned.

    NOTE(review): this function was previously documented as using
    ``kind::mxf4nvf4.block_scale.scale_vec::4X`` with UE4M3 scales
    (group_size=16) and recipe ``(1, 1, 16)``. That does not match the
    default recipe above, nor the UE8M0 block32 scales produced by
    ``transform_nvfp4_weights_for_mega_moe`` — confirm the expected weight
    format against the kernel.
    """
    kernel = _C.fp8_nvfp4_mega_moe

    # Pieces taken from the symmetric buffer object.
    symm_tensor = sym_buffer.buffer
    peer_ptrs = sym_buffer.handle.buffer_ptrs
    rank = sym_buffer.group.rank()
    max_tokens = sym_buffer.num_max_tokens_per_rank
    num_experts = sym_buffer.num_experts
    num_topk = sym_buffer.num_topk

    kernel(y,
           l1_weights, l2_weights,
           cumulative_local_expert_recv_stats,
           symm_tensor,
           peer_ptrs, rank,
           max_tokens,
           num_experts, num_topk,
           recipe,
           activation, activation_clamp,
           fast_math)