342 lines
14 KiB
Python
342 lines
14 KiB
Python
import torch
|
||
from typing import Tuple, Optional
|
||
from ..utils.math import align
|
||
|
||
# noinspection PyBroadException
try:
    # Bind `dist` first: signatures in this module use `dist.ProcessGroup`
    # in annotations that are evaluated at definition time, so the name must
    # exist even when the private symmetric-memory module below cannot be
    # imported. Otherwise the friendly message printed by the handler is
    # immediately followed by a NameError while the module is still loading.
    import torch.distributed as dist
    # noinspection PyProtectedMember
    import torch.distributed._symmetric_memory as symm_mem
except Exception as exception:
    print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
|
||
|
||
from .. import _C
|
||
|
||
|
||
class SymmBuffer:
    """Symmetric-memory workspace used by the mega MoE kernel.

    Construction asks ``_C.get_symm_buffer_size_for_mega_moe`` for the number
    of bytes required and for a slicing callable, allocates an ``int8``
    symmetric buffer of that size on CUDA, performs the rendezvous over
    ``group``, zeroes the buffer, and finally slices it into the named input
    tensors exposed as attributes.

    Attributes:
        group: Process group the buffer was rendezvoused on.
        num_experts, num_max_tokens_per_rank, num_topk, hidden,
        intermediate_hidden: The sizing arguments, stored unchanged.
        buffer: The raw ``int8`` symmetric allocation.
        handle: Object returned by ``symm_mem.rendezvous``.
        x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts,
        l2_acts_sf: Tensors returned by the slicing callable for ``buffer``.
    """

    # Attribute names of the tensors produced by the slicing callable, in the
    # order in which that callable returns them.
    _VIEW_NAMES = (
        'x', 'x_sf',
        'topk_idx', 'topk_weights',
        'l1_acts', 'l1_acts_sf',
        'l2_acts', 'l2_acts_sf',
    )

    def __init__(self, group: dist.ProcessGroup,
                 # MoE arguments
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and slice the symmetric buffer.

        Args:
            group: Process group whose ranks share the buffer.
            num_experts: Forwarded to the size query.
            num_max_tokens_per_rank: Forwarded to the size query.
            num_topk: Forwarded to the size query.
            hidden: Forwarded to the size query.
            intermediate_hidden: Forwarded to the size query.
            use_fp8_dispatch: Forwarded to the size query (not stored).
            activation: Forwarded to the size query (not stored).
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Start from an all-zero buffer; the barrier plus device sync make
        # sure every rank has finished zeroing before anyone proceeds.
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the symmetric buffer.

        All sliced tensors are cleared, not only ``x`` and ``x_sf``: they are
        presumably views aliasing ``buffer`` (NOTE(review): confirm against
        the C++ slicing code), so any survivor would keep the allocation
        alive after ``destroy()``.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        for view_name in self._VIEW_NAMES:
            setattr(self, view_name, None)
|
||
|
||
|
||
def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
                                 num_experts: int,
                                 num_max_tokens_per_rank: int, num_topk: int,
                                 hidden: int, intermediate_hidden: int,
                                 use_fp8_dispatch: bool = True,
                                 activation: str = 'swiglu') -> SymmBuffer:
    """Build a ``SymmBuffer`` whose per-rank token capacity is kernel-aligned.

    The requested ``num_max_tokens_per_rank`` is passed through ``align``
    together with the alignment reported by
    ``_C.get_token_alignment_for_mega_moe()``; every other argument is
    forwarded to ``SymmBuffer`` unchanged.

    Returns:
        The newly constructed ``SymmBuffer``.
    """
    # The kernel works on whole token blocks, so the capacity has to be
    # aligned to the block size before any buffer sizing happens.
    token_alignment = _C.get_token_alignment_for_mega_moe()
    aligned_num_tokens = align(num_max_tokens_per_rank, token_alignment)

    return SymmBuffer(group, num_experts,
                      aligned_num_tokens, num_topk,
                      hidden, intermediate_hidden,
                      use_fp8_dispatch, activation)
|
||
|
||
|
||
class NVFP4SymmBuffer:
    """Symmetric buffer for NVFP4 mega_moe kernel.

    Same structure as SymmBuffer but with NVFP4-specific SF layout:
    group_size=16 means 2x the SF data compared to MXFP4 (group_size=32).
    x_sf is packed int32 (4 UE4M3 per int32) with shape (M, K//64).

    Size and slicing come from ``_C.get_symm_buffer_size_for_nvfp4_mega_moe``;
    the sliced tensors are exposed as the attributes ``x``, ``x_sf``,
    ``topk_idx``, ``topk_weights``, ``l1_acts``, ``l1_acts_sf``, ``l2_acts``
    and ``l2_acts_sf``.
    """

    # Attribute names of the tensors produced by the slicing callable, in the
    # order in which that callable returns them.
    _VIEW_NAMES = (
        'x', 'x_sf',
        'topk_idx', 'topk_weights',
        'l1_acts', 'l1_acts_sf',
        'l2_acts', 'l2_acts_sf',
    )

    def __init__(self, group: dist.ProcessGroup,
                 num_experts: int,
                 num_max_tokens_per_rank: int, num_topk: int,
                 hidden: int, intermediate_hidden: int,
                 use_fp8_dispatch: bool = True,
                 activation: str = 'swiglu'):
        """Allocate, rendezvous and slice the NVFP4 symmetric buffer.

        Args:
            group: Process group whose ranks share the buffer.
            num_experts: Forwarded to the size query.
            num_max_tokens_per_rank: Forwarded to the size query.
            num_topk: Forwarded to the size query.
            hidden: Forwarded to the size query.
            intermediate_hidden: Forwarded to the size query.
            use_fp8_dispatch: Forwarded to the size query (not stored).
            activation: Forwarded to the size query (not stored).
        """
        self.group = group
        self.num_experts = num_experts
        self.num_max_tokens_per_rank = num_max_tokens_per_rank
        self.num_topk = num_topk
        self.hidden = hidden
        self.intermediate_hidden = intermediate_hidden

        # Allocate a symmetric buffer (NVFP4 variant)
        num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
            group.size(), num_experts,
            num_max_tokens_per_rank, num_topk,
            hidden, intermediate_hidden,
            use_fp8_dispatch, activation
        )
        self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
        self.handle = symm_mem.rendezvous(self.buffer, group=group)

        # Start from an all-zero buffer; the barrier plus device sync make
        # sure every rank has finished zeroing before anyone proceeds.
        self.buffer.zero_()
        self.group.barrier()
        torch.cuda.synchronize()

        # Create input buffer views
        (self.x, self.x_sf,
         self.topk_idx, self.topk_weights,
         self.l1_acts, self.l1_acts_sf,
         self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)

    def destroy(self):
        """Drop every reference this object holds to the symmetric buffer.

        All sliced tensors are cleared, not only ``x`` and ``x_sf``: they are
        presumably views aliasing ``buffer`` (NOTE(review): confirm against
        the C++ slicing code), so any survivor would keep the allocation
        alive after ``destroy()``.
        """
        self.handle = None
        self.buffer = None
        self.group = None
        for view_name in self._VIEW_NAMES:
            setattr(self, view_name, None)
|
||
|
||
|
||
def get_symm_buffer_for_nvfp4_mega_moe(group: dist.ProcessGroup,
                                       num_experts: int,
                                       num_max_tokens_per_rank: int, num_topk: int,
                                       hidden: int, intermediate_hidden: int,
                                       use_fp8_dispatch: bool = True,
                                       activation: str = 'swiglu') -> NVFP4SymmBuffer:
    """Build an ``NVFP4SymmBuffer`` whose token capacity is kernel-aligned.

    The requested ``num_max_tokens_per_rank`` is passed through ``align``
    together with the alignment reported by
    ``_C.get_token_alignment_for_nvfp4_mega_moe()``; every other argument is
    forwarded to ``NVFP4SymmBuffer`` unchanged.

    Returns:
        The newly constructed ``NVFP4SymmBuffer``.
    """
    # The kernel works on whole token blocks, so the capacity has to be
    # aligned to the block size before any buffer sizing happens.
    token_alignment = _C.get_token_alignment_for_nvfp4_mega_moe()
    aligned_num_tokens = align(num_max_tokens_per_rank, token_alignment)

    return NVFP4SymmBuffer(group, num_experts,
                           aligned_num_tokens, num_topk,
                           hidden, intermediate_hidden,
                           use_fp8_dispatch, activation)
|
||
|
||
|
||
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Interleave the gate and up halves of both L1 tensors along dim 1.

    Dim 1 of each tensor is read as ``[gate | up]`` (first half, second half)
    and rewritten in chunks of 8 rows as
    ``[gate 0..7, up 0..7, gate 8..15, up 8..15, ...]``.

    Args:
        l1_weights: Pair of tensors; each is interleaved independently.

    Returns:
        A pair of new tensors with the same shape, dtype and strides as the
        corresponding inputs.
    """
    def interleave(t, gran: int = 8) -> torch.Tensor:
        num_groups, n, *trailing = t.shape
        half = n // 2
        chunked_shape = (num_groups, half // gran, gran, *trailing)
        gate_chunks = t[:, :half].reshape(chunked_shape)
        up_chunks = t[:, half:].reshape(chunked_shape)
        # Joining along the chunk-row axis places each `gran`-row gate chunk
        # directly in front of the matching up chunk.
        mixed = torch.cat([gate_chunks, up_chunks], dim=2).reshape(num_groups, n, *trailing)
        out = torch.empty_like(t)
        out.copy_(mixed)
        return out

    weight, sf = l1_weights[0], l1_weights[1]
    return interleave(weight), interleave(sf)
|
||
|
||
|
||
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Reorder the rows of packed scale factors with a 4x32 block transpose.

    Every run of 128 rows along dim 1 is viewed as a ``(4, 32)`` grid and
    transposed to ``(32, 4)``, so output row ``i * 4 + j`` of a run holds
    input row ``j * 32 + i``.

    Args:
        sf: ``torch.int`` tensor of shape ``(num_groups, mn, packed_sf_k)``
            with ``mn`` divisible by 128.

    Returns:
        A new tensor with the same shape, dtype and strides as ``sf``.

    Raises:
        AssertionError: If the dtype or the row count does not qualify.
    """
    num_groups, mn, packed_sf_k = sf.shape
    assert sf.dtype == torch.int
    assert mn % 128 == 0

    grid = sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
    swapped = grid.permute(0, 1, 3, 2, 4).reshape(num_groups, mn, packed_sf_k)

    out = torch.empty_like(sf)
    out.copy_(swapped)
    return out
|
||
|
||
|
||
def transform_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Rearrange L1/L2 ``(weight, scale factor)`` pairs for the mega MoE kernel.

    L1: both tensors get their gate/up halves interleaved, then the scale
    factors are block-transposed for UTCCP.
    L2: the weight is passed through untouched and only the scale factors are
    block-transposed.

    Returns:
        ``((l1_weight, l1_sf), (l2_weight, l2_sf))`` in the transformed layout.
    """
    # L1 needs the gate/up interleave before its scale factors are transposed
    l1_weight, l1_sf = _interleave_l1_weights(l1_weights)
    l1_transformed = (l1_weight, _transpose_sf_for_utccp(l1_sf))

    # L2 has no gate/up split, so only the scale factors change
    l2_transformed = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))

    return l1_transformed, l2_transformed
|
||
|
||
|
||
def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
    """Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into the int32 UTCCP layout.

    The scales are reinterpreted as raw bytes; each run of 4 consecutive bytes
    along the last dim becomes one int32 (first byte in the lowest 8 bits).
    The packed rows then get the same 4x32 block transpose that the MXFP4
    path uses.

    Input:  (num_experts, mn, K//16) float8_e4m3fn scales
    Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales

    Raises:
        AssertionError: If the last dim is not a multiple of 4 or ``mn`` is
            not a multiple of 128.
    """
    num_groups, mn, sf_k = sf.shape
    assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
    assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"

    # Reinterpret the one-byte scales as plain bytes: (num_groups, mn, sf_k)
    raw_bytes = sf.view(torch.uint8)

    # Fold bytes 0..3 of every group of four into one int32, byte `i` landing
    # at bit offset 8 * i: (num_groups, mn, sf_k // 4)
    packed = raw_bytes[..., 0::4].to(torch.int32)
    for byte_idx in range(1, 4):
        packed = packed | (raw_bytes[..., byte_idx::4].to(torch.int32) << (8 * byte_idx))

    # UTCCP 4x32 transpose (same as MXFP4 - the transpose is determined by
    # the 128-element alignment, not the scale vector size)
    packed_sf_k = sf_k // 4
    grid = packed.reshape(num_groups, -1, 4, 32, packed_sf_k)
    swapped = grid.permute(0, 1, 3, 2, 4).reshape(num_groups, mn, packed_sf_k)

    out = torch.empty_like(packed)
    out.copy_(swapped)
    return out
|
||
|
||
|
||
def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor],
    l1_weight_scale_2: Optional[torch.Tensor] = None,
    l2_weight_scale_2: Optional[torch.Tensor] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 expert weights for the mega_moe kernel.

    NVFP4 weights come as (weight, scale) where:
    - weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
    - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)

    The kernel expects (weight, packed_sf) where packed_sf is int32 TMA-aligned
    UTCCP layout with gran_k=16.

    If weight_scale_2 (float32 global scale) is provided, it is folded into the
    block scales: effective_scale = block_scale * global_scale, re-quantized
    to UE4M3.

    Args:
        l1_weights: ``(weight, scale)`` pair of the first expert layer.
        l2_weights: ``(weight, scale)`` pair of the second expert layer.
        l1_weight_scale_2: Optional global scale folded into the L1 scales.
        l2_weight_scale_2: Optional global scale folded into the L2 scales.

    Returns:
        ``((l1_weight, l1_sf), (l2_weight, l2_sf))`` with the weights viewed
        as ``int8`` and the L1 pair gate/up interleaved.
    """
    from deep_gemm import transform_sf_into_required_layout

    def fold_global_scale(block_sf: torch.Tensor, global_scale: Optional[torch.Tensor]) -> torch.Tensor:
        """Fold weight_scale_2 into block scales: UE4M3 * FP32 -> UE4M3"""
        if global_scale is None:
            return block_sf
        # The stored scales are float8_e4m3fn values, so a plain dtype cast
        # (not an exponent-only bit shift) recovers their float32 value.
        widened = block_sf.to(torch.float32)
        if global_scale.dim() == 1:
            # One scale per expert: broadcast over the (N, K // 16) dims
            global_scale = global_scale.view(-1, 1, 1)
        # 448 is the largest finite float8_e4m3fn value; clamp before the
        # down-cast so that large products saturate.
        rescaled = (widened * global_scale).clamp(0.0, 448.0)
        return rescaled.to(torch.float8_e4m3fn)

    def pack_ue4m3_to_int32(block_sf):
        # 4 UE4M3 bytes -> 1 int32, byte `i` landing at bit offset 8 * i
        raw_bytes = block_sf.view(torch.uint8)
        assert raw_bytes.shape[-1] % 4 == 0
        packed = raw_bytes[..., 0::4].to(torch.int32)
        for byte_idx in range(1, 4):
            packed = packed | (raw_bytes[..., byte_idx::4].to(torch.int32) << (8 * byte_idx))
        return packed.contiguous()

    def to_kernel_sf(weights, global_scale, num_groups):
        weight = weights[0]
        mn = weight.shape[1]
        k = weight.shape[2] * 2  # K (weight is K//2 uint8)

        # Pack UE4M3 (float8_e4m3fn) into int32 for DeepGEMM TMA consumption
        packed = pack_ue4m3_to_int32(fold_global_scale(weights[1], global_scale))

        # Same logical shape but MN-major memory (stride(-2) == 1), which is
        # what transform_sf_into_required_layout expects for its TMA stride
        # checks.
        mn_major = packed.transpose(-2, -1).contiguous().transpose(-2, -1)

        # recipe (1, 16): gran_mn=1, gran_k=16 (NVFP4 native block16)
        return transform_sf_into_required_layout(mn_major, mn, k, (1, 16), num_groups)

    # The expert count is taken from the L1 weight and used for both layers
    num_experts = l1_weights[0].shape[0]
    l1_sf_transformed = to_kernel_sf(l1_weights, l1_weight_scale_2, num_experts)
    l2_sf_transformed = to_kernel_sf(l2_weights, l2_weight_scale_2, num_experts)

    # L1: interleave gate/up
    # NOTE(review): here the scale factors are interleaved *after* the layout
    # transform, whereas transform_weights_for_mega_moe interleaves *before*
    # its UTCCP transpose - confirm this ordering is what the kernel expects.
    l1_weight, l1_sf = _interleave_l1_weights((l1_weights[0], l1_sf_transformed))

    # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
    l1_out = (l1_weight.view(torch.int8), l1_sf)
    l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
    return l1_out, l2_out
|
||
|
||
|
||
def fp8_fp4_mega_moe(y: torch.Tensor,
                     l1_weights: Tuple[torch.Tensor, torch.Tensor],
                     l2_weights: Tuple[torch.Tensor, torch.Tensor],
                     sym_buffer: SymmBuffer,
                     cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                     recipe: Tuple[int, int, int] = (1, 1, 32),
                     activation: str = 'swiglu',
                     activation_clamp: Optional[float] = None,
                     fast_math: bool = True):
    """Launch the mega MoE kernel through ``_C.fp8_fp4_mega_moe``.

    A thin wrapper: it unpacks what the native entry point needs from
    ``sym_buffer`` and forwards everything else untouched.

    Args:
        y: Tensor handed to the kernel as its first argument (presumably the
            output - confirm against the C++ signature).
        l1_weights: ``(weight, scale factor)`` pair of the first expert layer.
        l2_weights: ``(weight, scale factor)`` pair of the second expert layer.
        sym_buffer: Symmetric buffer supplying the raw buffer, the peer
            buffer pointers, the local rank and the MoE sizes.
        cumulative_local_expert_recv_stats: Optional tensor forwarded as is.
        recipe: Forwarded as is.
        activation: Forwarded as is.
        activation_clamp: Forwarded as is.
        fast_math: Forwarded as is.
    """
    # Gather the buffer-derived arguments in the order the binding takes them
    raw_buffer = sym_buffer.buffer
    peer_buffer_ptrs = sym_buffer.handle.buffer_ptrs
    local_rank = sym_buffer.group.rank()
    num_max_tokens_per_rank = sym_buffer.num_max_tokens_per_rank
    num_experts = sym_buffer.num_experts
    num_topk = sym_buffer.num_topk

    _C.fp8_fp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        raw_buffer,
        peer_buffer_ptrs, local_rank,
        num_max_tokens_per_rank,
        num_experts, num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )
|
||
|
||
|
||
def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: 'NVFP4SymmBuffer',
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 16),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X
    with UE4M3 block scales (group_size=16).

    Both activations AND weights are E2M1 packed (FP4×FP4).
    Weight format: (uint8 E2M1 packed, int32 packed UTCCP UE4M3 scales)
    Activation format: E2M1 packed uint8 + UE4M3 scales (computed by staging kernel)
    Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16.

    Args:
        y: Tensor handed to the kernel as its first argument (presumably the
            output - confirm against the C++ signature).
        l1_weights: ``(weight, scale factor)`` pair of the first expert layer.
        l2_weights: ``(weight, scale factor)`` pair of the second expert layer.
        sym_buffer: Symmetric buffer supplying the raw buffer, the peer
            buffer pointers, the local rank and the MoE sizes. Only the
            attributes read below are required.
        cumulative_local_expert_recv_stats: Optional tensor forwarded as is.
        recipe: Forwarded as is.
        activation: Forwarded as is.
        activation_clamp: Forwarded as is.
        fast_math: Forwarded as is.
    """
    # Straight pass-through to the native entry point. No per-call
    # diagnostics are printed here: the weight arguments are
    # (tensor, tensor) pairs rather than tensors, and this function sits on
    # the kernel launch path.
    _C.fp8_nvfp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        sym_buffer.buffer,
        sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
        recipe,
        activation, activation_clamp,
        fast_math
    )
|