Files
DeepGEMM/deep_gemm/mega/__init__.py
biondizzle b3d1aae038 feat: full FP4 activations for mxf4nvf4 - E2M1 packed A side + UE4M3 scales
mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed).
Changes:
- a_dtype_t: float_e4m3_t → float_e2m1_unpacksmem_t
- UMMA_K: 32 → 64 (FP4 MMA atom)
- L1 epilogue: FP8 quant → E2M1 FP4 quantization with nearest-neighbor
- L1 output SMEM: packed E2M1 (2 per byte), TMA store uint8
- TMA descriptors: adjusted for FP4 packing (K/2 bytes per row)
- SymmBuffer: uint8 activations, shape (M, K//2)
- Staging kernel: BF16 → E2M1 packed + UE4M3 block16 scales
2026-05-11 20:29:08 +00:00

332 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from typing import Tuple, Optional
from ..utils.math import align
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed as dist
except Exception as exception:
print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
from .. import _C
class SymmBuffer:
    """Symmetric-memory workspace used by the mega MoE kernel.

    Construction queries `_C.get_symm_buffer_size_for_mega_moe` for the byte
    size and a slicing callable, allocates an int8 CUDA buffer through
    `torch.distributed._symmetric_memory`, rendezvouses it across `group`,
    zero-fills it, and finally slices it into the eight named input views
    (`x`, `x_sf`, `topk_idx`, `topk_weights`, `l1_acts`, `l1_acts_sf`,
    `l2_acts`, `l2_acts_sf`).

    Call `destroy()` to drop every reference this object holds on the
    allocation.
    """

    def __init__(self, group: dist.ProcessGroup,
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and slice the symmetric buffer.

        `use_fp8_dispatch` and `activation` are only forwarded to the size
        query; they are not stored on the instance.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Zero-fill, then make sure every rank got here and the device is idle
        # before the buffer is handed out.
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Release every reference this object holds on the symmetric buffer.

        All eight sliced views are cleared, not only `x`/`x_sf`: any view left
        set would keep referencing the buffer (they are created from it by
        `slice_input_buffers`) and could keep the allocation alive after the
        caller believes it has been torn down. Safe to call more than once.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None
def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Build a `SymmBuffer` whose per-rank token capacity is block-aligned.

    `num_max_tokens_per_rank` is passed through `align` together with the
    alignment reported by `_C.get_token_alignment_for_mega_moe()`; every other
    argument is handed to `SymmBuffer` unchanged.
    """
    # Token count must be aligned to block sizes
    token_alignment = _C.get_token_alignment_for_mega_moe()
    aligned_num_tokens = align(num_max_tokens_per_rank, token_alignment)
    return SymmBuffer(
        group, num_experts,
        aligned_num_tokens, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )
class NVFP4SymmBuffer:
    """Symmetric buffer for NVFP4 mega_moe kernel.

    Same structure as SymmBuffer, but sized and sliced by
    `_C.get_symm_buffer_size_for_nvfp4_mega_moe`.

    Layout notes carried over from the original author (they describe the C++
    side and cannot be checked from this file):
    group_size=16 means 2x the SF data compared to MXFP4 (group_size=32).
    x_sf is packed int32 (4 UE4M3 per int32) with shape (M, K//64).

    Call `destroy()` to drop every reference this object holds on the
    allocation.
    """

    def __init__(self, group: dist.ProcessGroup,
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and slice the NVFP4 symmetric buffer.

        `use_fp8_dispatch` and `activation` are only forwarded to the size
        query; they are not stored on the instance.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer (NVFP4 variant)
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Zero-fill, then make sure every rank got here and the device is idle
        # before the buffer is handed out.
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Release every reference this object holds on the symmetric buffer.

        All eight sliced views are cleared, not only `x`/`x_sf`: any view left
        set would keep referencing the buffer (they are created from it by
        `slice_input_buffers`) and could keep the allocation alive after the
        caller believes it has been torn down. Safe to call more than once.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None
def get_symm_buffer_for_nvfp4_mega_moe(group: dist.ProcessGroup,
                                       num_experts: int,
                                       num_max_tokens_per_rank: int, num_topk: int,
                                       hidden: int, intermediate_hidden: int,
                                       use_fp8_dispatch: bool = True,
                                       activation: str = 'swiglu') -> NVFP4SymmBuffer:
    """Build an `NVFP4SymmBuffer` whose per-rank token capacity is block-aligned.

    `num_max_tokens_per_rank` is passed through `align` together with the
    alignment reported by `_C.get_token_alignment_for_nvfp4_mega_moe()`; every
    other argument is handed to `NVFP4SymmBuffer` unchanged.
    """
    # Token count must be aligned to block sizes
    token_alignment = _C.get_token_alignment_for_nvfp4_mega_moe()
    aligned_num_tokens = align(num_max_tokens_per_rank, token_alignment)
    return NVFP4SymmBuffer(
        group, num_experts,
        aligned_num_tokens, num_topk,
        hidden, intermediate_hidden,
        use_fp8_dispatch, activation
    )
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
def interleave(t, gran: int = 8) -> torch.Tensor:
g, n, *rest = t.shape
half = n // 2
gate = t[:, :half].reshape(g, half // gran, gran, *rest)
up = t[:, half:].reshape(g, half // gran, gran, *rest)
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
return interleave(l1_weights[0]), interleave(l1_weights[1])
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
num_groups, mn, packed_sf_k = sf.shape
assert sf.dtype == torch.int and mn % 128 == 0
result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(sf).copy_(result)
def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Prepare (weight, scale-factor) pairs of both MoE layers for the mega kernel.

    L1: gate/up rows of weight and SF are interleaved, then the SF goes through
    `_transpose_sf_for_utccp`.
    L2: the weight is passed through untouched; only the SF is transposed.

    Returns the transformed `(l1_weights, l2_weights)` pair.
    """
    # L1: interleave gate/up, then transpose SF for UTCCP
    l1_weight, l1_sf = _interleave_l1_weights(l1_weights)
    l1_transformed = (l1_weight, _transpose_sf_for_utccp(l1_sf))

    # L2: only transpose SF for UTCCP
    l2_transformed = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))

    return l1_transformed, l2_transformed
def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
"""Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.
NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X).
The UTCCP layout packs 4 consecutive scale bytes into each int32,
then applies the 4x32 transpose for TMA consumption.
Input: (num_experts, mn, K//16) float8_e4m3fn scales
Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
"""
num_groups, mn, sf_k = sf.shape
assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"
# View as uint8 and pack 4 consecutive bytes into int32
sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k)
# Pack: every 4 uint8 → 1 int32
packed = (sf_uint8[..., 0::4].to(torch.int32) |
(sf_uint8[..., 1::4].to(torch.int32) << 8) |
(sf_uint8[..., 2::4].to(torch.int32) << 16) |
(sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4)
# Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
# by the 128-element alignment, not the scale vector size)
packed_sf_k = sf_k // 4
result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(packed).copy_(result)
def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor],
    l1_weight_scale_2: Optional[torch.Tensor] = None,
    l2_weight_scale_2: Optional[torch.Tensor] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 expert weights for the mega_moe kernel.

    NVFP4 weights come as (weight, scale) where:
    - weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
    - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)

    The kernel expects (weight, packed_sf) where the weight is viewed as int8
    and packed_sf is the int32 output of `transform_sf_into_required_layout`
    with recipe (1, 16).

    If a weight_scale_2 (float32 global scale) is given for a layer, it is
    multiplied into that layer's block scales, the product is clamped to
    [0, 448] and re-quantized to float8_e4m3fn before packing.

    Pipeline per layer: fold global scale -> pack 4 bytes per int32 ->
    MN-major strides -> `transform_sf_into_required_layout`. L1 weight and SF
    are additionally gate/up interleaved at the end.
    """
    from deep_gemm import transform_sf_into_required_layout

    def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
        """Fold weight_scale_2 into block scales and re-quantize to float8_e4m3fn."""
        if scale_2 is None:
            return sf
        # Each raw byte is placed into the float32 exponent field (bits 23..30),
        # i.e. the byte is decoded as a biased power-of-two exponent (UE8M0
        # style) rather than being cast from float8_e4m3fn. The original author
        # recorded this as a deliberate fix ("Bug #7: E8M0 bytes misinterpreted
        # as E4M3").
        # NOTE(review): this conflicts with the documented input format (UE4M3
        # block scales) and with the `scale_2 is None` path above, which hands
        # the same bytes onward unchanged. Both paths cannot be right for the
        # same input encoding - confirm which one the checkpoints actually use.
        exponent_bits = sf.view(torch.uint8).to(torch.int32) << 23
        decoded = exponent_bits.view(torch.float32)
        if scale_2.dim() == 1:
            # One global scale per expert: broadcast over the (N, K) dims
            scale_2 = scale_2.view(-1, 1, 1)
        folded = (decoded * scale_2).clamp(0.0, 448.0)
        return folded.to(torch.float8_e4m3fn)

    def pack_ue4m3_to_int32(sf):
        # Pack UE4M3 (float8_e4m3fn) into int32 for DeepGEMM TMA consumption:
        # 4 scale bytes -> 1 int32, first byte in the lowest 8 bits
        raw = sf.view(torch.uint8)
        assert raw.shape[-1] % 4 == 0
        words = raw[..., 0::4].to(torch.int32)
        for lane in (1, 2, 3):
            words = words | (raw[..., lane::4].to(torch.int32) << (8 * lane))
        return words.contiguous()

    def to_mn_major(packed):
        # Same logical shape, but with stride(-2) == 1;
        # transform_sf_into_required_layout expects MN-major input for its
        # TMA stride checks
        return packed.transpose(-2, -1).contiguous().transpose(-2, -1)

    l1_weight, l2_weight = l1_weights[0], l2_weights[0]

    # Fold global scales into block scales
    l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2)
    l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)

    # Logical GEMM dims; the packed weight stores two E2M1 values per byte,
    # so K is twice the last dim
    num_experts = l1_weight.shape[0]
    l1_n, l1_k = l1_weight.shape[1], l1_weight.shape[2] * 2
    l2_n, l2_k = l2_weight.shape[1], l2_weight.shape[2] * 2

    l1_sf_mn = to_mn_major(pack_ue4m3_to_int32(l1_sf))
    l2_sf_mn = to_mn_major(pack_ue4m3_to_int32(l2_sf))

    # Transform SF into the required layout using DeepGEMM's function
    # recipe (1, 16): gran_mn=1, gran_k=16 (NVFP4 native block16)
    sf_recipe = (1, 16)
    l1_sf_transformed = transform_sf_into_required_layout(
        l1_sf_mn, l1_n, l1_k, sf_recipe, num_experts)
    l2_sf_transformed = transform_sf_into_required_layout(
        l2_sf_mn, l2_n, l2_k, sf_recipe, num_experts)

    # L1: interleave gate/up
    l1_weight_il, l1_sf_il = _interleave_l1_weights((l1_weight, l1_sf_transformed))

    # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
    l1_out = (l1_weight_il.view(torch.int8), l1_sf_il)
    l2_out = (l2_weight.view(torch.int8), l2_sf_transformed)
    return l1_out, l2_out
def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Launch the `_C.fp8_fp4_mega_moe` kernel using a `SymmBuffer`.

    This is a thin forwarding wrapper: the buffer tensor, the peer buffer
    pointers from the rendezvous handle, this rank's index within the group
    and the MoE sizes are read from `sym_buffer`; every other argument is
    passed to the extension as given. Nothing is returned.
    `y` is presumably the output tensor written by the kernel - TODO confirm
    against the C++ binding.
    """
    peer_buffer_ptrs = sym_buffer.handle.buffer_ptrs
    local_rank = sym_buffer.group.rank()
    _C.fp8_fp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer,
        peer_buffer_ptrs, local_rank,
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )
def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: 'NVFP4SymmBuffer',
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 16),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X
    with UE4M3 block scales (group_size=16).

    Both activations AND weights are E2M1 packed (FP4×FP4).
    Weight format: (uint8 E2M1 packed, int32 packed UTCCP UE4M3 scales)
    Activation format: E2M1 packed uint8 + UE4M3 scales (computed by staging kernel)
    Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16.

    `sym_buffer` is annotated as `NVFP4SymmBuffer` (the type returned by
    `get_symm_buffer_for_nvfp4_mega_moe`) rather than `SymmBuffer`: the two
    classes are unrelated and are sized by different `_C` queries. The
    annotation is a forward reference and is not enforced at runtime; only the
    attributes `buffer`, `handle`, `group`, `num_max_tokens_per_rank`,
    `num_experts` and `num_topk` are read.

    This is a thin forwarding wrapper around `_C.fp8_nvfp4_mega_moe`; nothing
    is returned.
    """
    _C.fp8_nvfp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer,
        sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )