The C++ transform function expects an int32 tensor (the kInt dtype) with four UE4M3 bytes packed into each int32. We therefore pack first, then call the transform, which applies TMA alignment and the UTCCP transpose with recipe (1, 16).
254 lines
10 KiB
Python
254 lines
10 KiB
Python
import torch
|
|
from typing import Tuple, Optional
|
|
from ..utils.math import align
|
|
|
|
# noinspection PyBroadException
|
|
try:
|
|
# noinspection PyProtectedMember
|
|
import torch.distributed._symmetric_memory as symm_mem
|
|
import torch.distributed as dist
|
|
except Exception as exception:
|
|
print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
|
|
|
|
from .. import _C
|
|
|
|
|
|
class SymmBuffer:
    """Symmetric-memory workspace shared by every rank of ``group`` for the mega MoE kernels.

    The byte size of the workspace and the function that slices it into typed
    input views both come from ``_C.get_symm_buffer_size_for_mega_moe``. This
    class allocates the raw ``int8`` CUDA storage, performs the symmetric-memory
    rendezvous, zero-fills the storage and exposes the views as attributes.
    """

    def __init__(self, group: dist.ProcessGroup,
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and zero the symmetric buffer for ``group``.

        Args:
            group: process group whose ranks share the buffer.
            num_experts: total number of experts.
            num_max_tokens_per_rank: per-rank token capacity (callers normally
                pass an already aligned value; see ``get_symm_buffer_for_mega_moe``).
            num_topk: experts selected per token.
            hidden: model hidden size.
            intermediate_hidden: expert intermediate size.
            use_fp8_dispatch: forwarded to the C++ sizing function.
            activation: forwarded to the C++ sizing function.
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Ask the C++ side how many bytes are needed and how to slice them
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )

        # Allocate a symmetric buffer and exchange handles with the other ranks
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Zero-fill, then barrier + device sync -- presumably so no rank touches a
        # peer's buffer before that peer has cleared it
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views (carved out of ``self.buffer``)
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the symmetric buffer.

        All eight input views are cleared as well, not just ``x``/``x_sf``.
        The views are carved out of ``self.buffer``, so any view left set would
        keep the symmetric allocation alive after ``destroy()``. The method is
        safe to call more than once.
        """
        self.handle = None
        self.buffer = None
        self.group = None

        self.x = None
        self.x_sf = None
        self.topk_idx = None
        self.topk_weights = None
        self.l1_acts = None
        self.l1_acts_sf = None
        self.l2_acts = None
        self.l2_acts_sf = None
|
|
|
|
|
|
def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Create a ``SymmBuffer`` whose per-rank token capacity is block-aligned.

    ``num_max_tokens_per_rank`` is passed through ``align`` with the granularity
    reported by ``_C.get_token_alignment_for_mega_moe()``. Every other argument
    is forwarded to ``SymmBuffer`` unchanged.
    """
    token_alignment = _C.get_token_alignment_for_mega_moe()
    aligned_tokens = align(num_max_tokens_per_rank, token_alignment)
    return SymmBuffer(group=group,
                      num_experts=num_experts,
                      num_max_tokens_per_rank=aligned_tokens,
                      num_topk=num_topk,
                      hidden=hidden,
                      intermediate_hidden=intermediate_hidden,
                      use_fp8_dispatch=use_fp8_dispatch,
                      activation=activation)
|
|
|
|
|
|
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
|
|
def interleave(t, gran: int = 8) -> torch.Tensor:
|
|
g, n, *rest = t.shape
|
|
half = n // 2
|
|
gate = t[:, :half].reshape(g, half // gran, gran, *rest)
|
|
up = t[:, half:].reshape(g, half // gran, gran, *rest)
|
|
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
|
|
|
|
return interleave(l1_weights[0]), interleave(l1_weights[1])
|
|
|
|
|
|
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
|
|
num_groups, mn, packed_sf_k = sf.shape
|
|
assert sf.dtype == torch.int and mn % 128 == 0
|
|
result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
|
|
.transpose(2, 3)
|
|
.reshape(num_groups, mn, packed_sf_k))
|
|
return torch.empty_like(sf).copy_(result)
|
|
|
|
|
|
def transform_weights_for_mega_moe(
        l1_weights: Tuple[torch.Tensor, torch.Tensor],
        l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Prepare ``(weight, scale_factor)`` pairs for the FP8 x FP4 mega MoE kernel.

    L1: gate/up rows of both the weight and its scale factors are interleaved.
    The interleaved scale factors then get the UTCCP 4x32 row transpose.

    L2: the weight is returned unchanged; only its scale factors are transposed.
    """
    l1_w, l1_sf = _interleave_l1_weights(l1_weights)
    l2_w, l2_sf = l2_weights[0], l2_weights[1]
    return (l1_w, _transpose_sf_for_utccp(l1_sf)), (l2_w, _transpose_sf_for_utccp(l2_sf))
|
|
|
|
|
|
def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
|
|
"""Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.
|
|
|
|
NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X).
|
|
The UTCCP layout packs 4 consecutive scale bytes into each int32,
|
|
then applies the 4x32 transpose for TMA consumption.
|
|
|
|
Input: (num_experts, mn, K//16) float8_e4m3fn scales
|
|
Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
|
|
"""
|
|
num_groups, mn, sf_k = sf.shape
|
|
assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
|
|
assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"
|
|
|
|
# View as uint8 and pack 4 consecutive bytes into int32
|
|
sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k)
|
|
# Pack: every 4 uint8 → 1 int32
|
|
packed = (sf_uint8[..., 0::4].to(torch.int32) |
|
|
(sf_uint8[..., 1::4].to(torch.int32) << 8) |
|
|
(sf_uint8[..., 2::4].to(torch.int32) << 16) |
|
|
(sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4)
|
|
|
|
# Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
|
|
# by the 128-element alignment, not the scale vector size)
|
|
packed_sf_k = sf_k // 4
|
|
result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k)
|
|
.transpose(2, 3)
|
|
.reshape(num_groups, mn, packed_sf_k))
|
|
return torch.empty_like(packed).copy_(result)
|
|
|
|
|
|
def transform_nvfp4_weights_for_mega_moe(
        l1_weights: Tuple[torch.Tensor, torch.Tensor],
        l2_weights: Tuple[torch.Tensor, torch.Tensor],
        l1_weight_scale_2: Optional[torch.Tensor] = None,
        l2_weight_scale_2: Optional[torch.Tensor] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 expert weights for the mega_moe kernel.

    ``l1_weights`` and ``l2_weights`` are ``(weight, scale_factor)`` pairs.
    The weight stores two E2M1 values per byte (hence ``K = weight.shape[2] * 2``).
    The scale factor stores one UE4M3 (``float8_e4m3fn``) byte per group of 16
    K-elements.

    Pipeline:
      1. If ``l{1,2}_weight_scale_2`` is given, fold that global scale into the
         block scales. The result is clamped to [0, 448] and re-quantized to E4M3.
      2. Pack 4 consecutive UE4M3 bytes into one int32.
      3. L1 only: interleave the gate/up rows of the weight AND of the packed
         scales, so that scale row ``i`` still describes weight row ``i``.
      4. Lay out the packed scales for TMA/UTCCP with
         ``deep_gemm.transform_sf_into_required_layout`` and recipe ``(1, 16)``
         (gran_mn=1, gran_k=16 for NVFP4 group_size=16).
      5. Reinterpret the weights as ``int8``.

    Returns ``((l1_weight, l1_sf), (l2_weight, l2_sf))``.
    """
    from deep_gemm import transform_sf_into_required_layout

    def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
        # Multiply the block scales by the global scale so only one scale per block remains.
        # NOTE(review): the product is re-quantized to E4M3 and clamped to [0, 448]; large products
        # saturate and very small ones lose precision -- confirm this is acceptable for the checkpoints used.
        if scale_2 is None:
            return sf
        sf_f32 = sf.to(torch.float32)
        # A non-scalar tensor on another device (e.g. a CPU per-expert scale) cannot broadcast
        # against CUDA scales, so move it next to them first.
        scale_2 = scale_2.to(sf_f32.device)
        if scale_2.dim() == 1:
            # One value per leading (expert) index: broadcast over the remaining two dims
            scale_2 = scale_2.reshape(-1, 1, 1)
        return (sf_f32 * scale_2).clamp(0.0, 448.0).to(torch.float8_e4m3fn)

    def pack_ue4m3_to_int32(sf: torch.Tensor) -> torch.Tensor:
        # Pack 4 consecutive UE4M3 bytes along the last dim into one int32, little-endian
        # (byte 4*j + t -> bits [8*t, 8*t + 8) of word j): (..., sf_k) -> (..., sf_k // 4)
        sf_u8 = sf.view(torch.uint8)
        assert sf_u8.shape[-1] % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_u8.shape[-1]}"
        packed = (sf_u8[..., 0::4].to(torch.int32) |
                  (sf_u8[..., 1::4].to(torch.int32) << 8) |
                  (sf_u8[..., 2::4].to(torch.int32) << 16) |
                  (sf_u8[..., 3::4].to(torch.int32) << 24))
        return packed.contiguous()

    l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2)
    l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)

    num_experts = l1_weights[0].shape[0]
    # Two E2M1 values per byte, so the logical K is twice the stored last dim
    l1_n, l1_k = l1_weights[0].shape[1], l1_weights[0].shape[2] * 2
    l2_n, l2_k = l2_weights[0].shape[1], l2_weights[0].shape[2] * 2

    # Pack UE4M3 (float8_e4m3fn) into int32 for DeepGEMM TMA consumption
    l1_sf_packed = pack_ue4m3_to_int32(l1_sf)
    l2_sf_packed = pack_ue4m3_to_int32(l2_sf)

    # L1: interleave the gate/up rows of the weight and of its packed scales *together*, and do it
    # before the layout transform (same order as `transform_weights_for_mega_moe`). Transforming
    # the un-interleaved scales would pair interleaved weight row i with another row's scales.
    l1_weight_interleaved, l1_sf_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed))

    # The C++ transform takes the packed tensor as kInt with recipe (1, 16): gran_mn=1, gran_k=16.
    # Each int32 holds 4 UE4M3 scales, so check_sf_layout's expected last dim
    # ceil_div(k, gran_k * 4) = ceil_div(k, 64) matches our packed shape (experts, mn, K/64).
    l1_sf_transformed = transform_sf_into_required_layout(
        l1_sf_interleaved, l1_n, l1_k, (1, 16), num_experts)
    l2_sf_transformed = transform_sf_into_required_layout(
        l2_sf_packed, l2_n, l2_k, (1, 16), num_experts)

    # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
    l1_out = (l1_weight_interleaved.view(torch.int8), l1_sf_transformed)
    l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
    return l1_out, l2_out
|
|
|
|
|
|
def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Launch the fused FP8 x FP4 mega MoE kernel ``_C.fp8_fp4_mega_moe``.

    ``y``, the weight pairs, the optional receive-stats tensor, ``recipe``,
    ``activation``, ``activation_clamp`` and ``fast_math`` are forwarded as-is.
    ``y`` is presumably the output written by the kernel -- confirm against the
    C++ side. The communication state (workspace, peer pointers, rank,
    capacities) is read from ``sym_buffer``. Any value the C++ call returns is
    discarded, so this function returns ``None``.
    """
    workspace = sym_buffer.buffer
    peer_ptrs = sym_buffer.handle.buffer_ptrs
    rank = sym_buffer.group.rank()
    _C.fp8_fp4_mega_moe(
        y, l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        workspace, peer_ptrs, rank,
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe, activation, activation_clamp, fast_math,
    )
|
|
|
|
|
|
def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: SymmBuffer,
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 16),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X
    with UE4M3 block scales (group_size=16).

    Weight format: ``(E2M1-packed weight, int32-packed UE4M3 scales)``, as
    produced by ``transform_nvfp4_weights_for_mega_moe``. That function returns
    the weight reinterpreted as ``int8`` and the scales in the TMA-aligned
    UTCCP layout.

    Recipe: (1, 1, 16) -- kGranK=16 for NVFP4 group_size=16.

    Communication state is read from ``sym_buffer``. Any value the C++ call
    returns is discarded, so this function returns ``None``.
    """
    workspace = sym_buffer.buffer
    peer_ptrs = sym_buffer.handle.buffer_ptrs
    rank = sym_buffer.group.rank()
    _C.fp8_nvfp4_mega_moe(
        y, l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        workspace, peer_ptrs, rank,
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe, activation, activation_clamp, fast_math,
    )
|