Files
nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py

383 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
NVFP4 Mega MoE Kernel — Full MoE with expert parallelism.
This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM.
Architecture:
- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- SiLU+Mul activation
- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- NVLink cross-rank sync handled by caller (not this kernel)
- Expert parallel: each rank handles NUM_EXPERTS / NUM_RANKS experts (NUM_RANKS = 8)
The kernel uses native NVFP4 block-scaled MMA via tcgen05.mma
kind::mxf8f6f4.block_scale on Blackwell (SM100).
Native NVFP4 path:
E2M1 (int8, 2 vals/byte) × E2M1 + UE4M3 block-16 scales
→ native hardware block-scaled MMA in tensor cores
→ float32 accumulator
This replaces the dequantize-then-BF16-GEMM approach. The native path
performs the E2M1 × E2M1 with UE4M3 block scaling entirely in hardware,
avoiding the costly dequantization step.
"""
import os
import torch
def unpack_ue4m3_u32(x_u32):
"""Unpack uint32 packed UE4M3 scales to float8_e4m3fn.
Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32].
Must use bit reinterpret (view), NOT value cast (to) — byte 0x3F is the float8
whose bits are 0x3F (~0.984), NOT the integer 63.
CUDA doesn't implement bitwise ops on uint32, so we cast to int32 first.
"""
# CUDA uint32 lacks bitwise ops — use int32
x_i32 = x_u32.to(torch.int32)
M, N = x_i32.shape
# Extract 4 bytes, cast to uint8, then bit-reinterpret to float8_e4m3fn
b0 = (x_i32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
b1 = ((x_i32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
b2 = ((x_i32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
b3 = ((x_i32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
# Interleave into (M, N*4)
out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
out[:, 0::4] = b0
out[:, 1::4] = b1
out[:, 2::4] = b2
out[:, 3::4] = b3
return out
# ---------------------------------------------------------------------------
# Backend selection: CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell).
# The primary path is CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled,
# which issues mxf8f6f4.block_scale tensor core instructions directly.
# ---------------------------------------------------------------------------
# NOTE(review): this flag is parsed here but not consulted anywhere in the
# visible part of the file — confirm whether it is still needed.
MEGA_MOE_USE_CUTLASS = int(os.getenv("MEGA_MOE_USE_CUTLASS", "1"))

try:
    from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import (
        cutlass_nvfp4_blockscaled_gemm,
        cutlass_grouped_nvfp4_gemm,
    )
except ImportError:
    # The extension is optional at import time; the GEMM names stay undefined.
    _CUTLASS_AVAILABLE = False
else:
    _CUTLASS_AVAILABLE = True

# Model dimensions (DeepSeek-V4-Pro)
HIDDEN = 7168
INTERMEDIATE = 3072
NUM_EXPERTS = 256
NUM_RANKS = 8
NUM_TOPK = 6

# NVFP4 scale-factor layout
SF_GRANULARITY_K = 16  # elements along K covered by one UE4M3 scale
SF_PACK_FACTOR = 4     # UE4M3 scales packed into one uint32 word

# Runtime switches, read once from the environment at import time
MEGA_MOE_STATIC = int(os.getenv("MEGA_MOE_STATIC", "0"))
MEGA_MOE_DEBUG = int(os.getenv("MEGA_MOE_DEBUG", "0"))
# ---------------------------------------------------------------------------
# Main kernel entry points
# ---------------------------------------------------------------------------
def nvfp4_mega_moe_l1(
    x_fp4,          # (num_tokens, K//2) int8 packed E2M1
    x_sf,           # (num_tokens, sf_k_groups) UE4M3 scales: uint32-packed or float8_e4m3fn
    l1_weights,     # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
    l1_scales,      # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn (or uint32-packed)
    topk_ids,       # (num_tokens, NUM_TOPK) int32
    topk_weights,   # (num_tokens, NUM_TOPK) float32
    num_experts_per_rank,
    alpha=1.0,      # fp32 scalar from stage_activation global scale
):
    """L1 GEMM (gate_up_proj) via the CUTLASS grouped NVFP4 GEMM.

    Scale tensors that arrive as uint32 are unpacked to float8_e4m3fn with
    ``unpack_ue4m3_u32``; any other dtype is forwarded unchanged. All the
    arithmetic is delegated to ``cutlass_grouped_nvfp4_gemm``. There is no
    dequantize+BF16 fallback in this function: if the CUTLASS extension
    could not be imported the call fails with a ``RuntimeError``.

    Args:
        x_fp4: packed E2M1 activations.
        x_sf: activation block scales.
        l1_weights: packed E2M1 expert weights.
        l1_scales: weight block scales.
        topk_ids: per-token expert indices, forwarded to the GEMM.
        topk_weights: per-token routing weights, forwarded to the GEMM.
        num_experts_per_rank: only used in the debug log line.
        alpha: global scale forwarded to the GEMM as ``alpha``.

    Returns:
        Whatever ``cutlass_grouped_nvfp4_gemm`` returns — per the original
        author's note, a (num_tokens, 6144) bfloat16 tensor.

    Raises:
        RuntimeError: the CUTLASS NVFP4 extension is not importable.
    """
    if not _CUTLASS_AVAILABLE:
        # Without this guard the call below dies with an opaque NameError.
        raise RuntimeError(
            "nvfp4_mega_moe_l1 requires the CUTLASS NVFP4 grouped GEMM "
            "(nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel), "
            "which could not be imported"
        )

    if MEGA_MOE_DEBUG:
        # Shapes are only needed for logging; K is twice the packed width.
        num_tokens = x_fp4.shape[0]
        K = x_fp4.shape[1] * 2
        N = l1_weights.shape[2]
        print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
              f"experts={num_experts_per_rank} native=1")

    # Unpack uint32 packed UE4M3 scales to float8_e4m3fn when necessary.
    x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
    w_sf_fp8 = (
        unpack_ue4m3_u32(l1_scales)
        if l1_scales.dtype == torch.uint32
        else l1_scales
    )

    return cutlass_grouped_nvfp4_gemm(
        x_fp4, x_sf_fp8,
        l1_weights, w_sf_fp8,
        topk_ids, topk_weights,
        alpha=alpha,
    )
def nvfp4_mega_moe_l2(
    x_fp4,          # (num_tokens, INTER//2) int8 packed E2M1
    x_sf,           # (num_tokens, sf_k_groups) UE4M3 scales: uint32-packed or float8_e4m3fn
    l2_weights,     # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
    l2_scales,      # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn (or uint32-packed)
    topk_ids,       # (num_tokens, NUM_TOPK) int32
    topk_weights,   # (num_tokens, NUM_TOPK) float32
    num_experts_per_rank,
    alpha=1.0,      # fp32 scalar from stage_activation global scale
):
    """L2 GEMM (down_proj) via the CUTLASS grouped NVFP4 GEMM.

    Same pipeline as the L1 entry point: uint32 scale tensors are unpacked
    to float8_e4m3fn with ``unpack_ue4m3_u32`` (other dtypes are forwarded
    unchanged) and the work is delegated to ``cutlass_grouped_nvfp4_gemm``.
    There is no fallback path; a missing CUTLASS extension is reported with
    a ``RuntimeError``.

    Args:
        x_fp4: packed E2M1 activations (the quantized SiLU*up output).
        x_sf: activation block scales.
        l2_weights: packed E2M1 expert weights.
        l2_scales: weight block scales.
        topk_ids: per-token expert indices, forwarded to the GEMM.
        topk_weights: per-token routing weights, forwarded to the GEMM.
        num_experts_per_rank: only used in the debug log line.
        alpha: global scale forwarded to the GEMM as ``alpha``.

    Returns:
        Whatever ``cutlass_grouped_nvfp4_gemm`` returns — per the original
        author's note, a (num_tokens, 7168) bfloat16 tensor.

    Raises:
        RuntimeError: the CUTLASS NVFP4 extension is not importable.
    """
    if not _CUTLASS_AVAILABLE:
        # Without this guard the call below dies with an opaque NameError.
        raise RuntimeError(
            "nvfp4_mega_moe_l2 requires the CUTLASS NVFP4 grouped GEMM "
            "(nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel), "
            "which could not be imported"
        )

    if MEGA_MOE_DEBUG:
        # Shapes are only needed for logging; K is twice the packed width.
        num_tokens = x_fp4.shape[0]
        K = x_fp4.shape[1] * 2
        N = l2_weights.shape[2]
        print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} "
              f"experts={num_experts_per_rank} native=1")

    # Unpack uint32 packed UE4M3 scales to float8_e4m3fn when necessary.
    x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
    w_sf_fp8 = (
        unpack_ue4m3_u32(l2_scales)
        if l2_scales.dtype == torch.uint32
        else l2_scales
    )

    return cutlass_grouped_nvfp4_gemm(
        x_fp4, x_sf_fp8,
        l2_weights, w_sf_fp8,
        topk_ids, topk_weights,
        alpha=alpha,
    )
# E2M1 (FP4) representable magnitudes: {0, 0.5, 1, 1.5, 2, 3, 4, 6}
# Bit patterns (3-bit, no sign): 000=0, 001=0.5, 010=1, 011=1.5, 100=2, 101=3, 110=4, 111=6
# Full 4-bit nibble: bit 3 = sign, bits 2:0 = magnitude index
_E2M1_MAGNITUDES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.float32)
def _quantize_to_e2m1(x_f32):
"""Quantize float32 values to E2M1 (FP4) nibble indices.
Maps each value to the nearest E2M1 representable magnitude,
then packs as 4-bit sign-magnitude nibbles.
Returns (nibbles, scales) where:
nibbles: (..., N) uint8 with 4-bit sign-magnitude per value
scales: (..., N//16) float8_e4m3fn block scales
"""
*batch, N = x_f32.shape
assert N % 16 == 0, f"Last dim {N} not divisible by 16 (block size)"
# Reshape into blocks of 16 for block-wise scaling
x_blocks = x_f32.reshape(*batch, N // 16, 16)
# Per-block absmax determines the scale
block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8, max=448.0)
# Scale so that the max maps to 6.0 (largest E2M1 magnitude)
# Dequant: x_reconstructed = x_e2m1 * scale, where scale = block_max / 6.0
scale_f32 = block_max / 6.0
x_scaled = x_blocks / scale_f32.clamp(min=1e-8)
# Find nearest E2M1 magnitude for each value
signs = torch.sign(x_scaled) # +1, -1, or 0
abs_scaled = x_scaled.abs() # 0..6 range
# Nearest E2M1 magnitude: find closest in {0, 0.5, 1, 1.5, 2, 3, 4, 6}
mags = _E2M1_MAGNITUDES.to(device=abs_scaled.device)
# Distance from each value to each magnitude
dists = (abs_scaled.unsqueeze(-1) - mags).abs() # (..., 16, 8)
idx = dists.argmin(dim=-1) # (..., 16) — index into E2M1 magnitudes
# Clamp to valid range (safety)
idx = idx.clamp(0, 7).to(torch.uint8)
# Build 4-bit sign-magnitude nibble: bit3=sign, bits2:0=magnitude index
sign_bit = (signs < 0).to(torch.uint8) # 1 if negative
nibbles = (sign_bit << 3) | idx # (..., 16) uint8, values 0..15
# Pack 2 nibbles per byte: low nibble = even index, high nibble = odd index
nibbles = nibbles.reshape(*batch, N // 2, 2)
packed = (nibbles[..., 1] << 4) | nibbles[..., 0] # (..., N//2) uint8
# Scale factors: what the GEMM needs to reconstruct the original values
# dequant = e2m1_magnitude * scale, so scale = block_max / 6.0
sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn) # (..., N//16)
return packed.to(torch.int8), sf
def stage_activation(x_bf16):
    """Quantize an activation tensor to FP4 (E2M1) with block-16 fp8 scales.

    Two-level quantization:
      1. Per-tensor global scale ``amax / (6.0 * 448.0)``. Dividing by it
         brings the tensor's absmax to 6 * 448, so the largest per-block
         scale (block absmax / 6) is about 448, the largest finite
         float8_e4m3fn value.
      2. Per-block (16 values) quantization of the normalized tensor,
         delegated to ``_quantize_to_e2m1``.

    The global scale is returned separately instead of being folded into
    the block scales; callers pass it to the GEMM as ``alpha`` so it stays
    in fp32 and is not rounded to fp8.

    NOTE(review): step 1 relies on ``_quantize_to_e2m1`` accepting block
    scales up to 448 — verify its clamping matches.

    Args:
        x_bf16: activation tensor of shape (..., N), N divisible by 16.
            It is converted with ``.float()`` before any arithmetic. An
            empty tensor (zero tokens) is accepted.

    Returns:
        (x_fp4, x_sf, input_global_scale) where
            x_fp4: packed E2M1 nibbles from ``_quantize_to_e2m1``.
            x_sf: block scales from ``_quantize_to_e2m1`` (global scale
                NOT folded in).
            input_global_scale: 0-dim float32 tensor, the per-tensor scale.
    """
    x_f32 = x_bf16.float()

    if x_f32.numel() == 0:
        # amax() without a dim raises on an empty tensor. A call with zero
        # tokens still needs a finite, non-zero scale, so use the same
        # floor the clamp below would produce.
        x_amax = torch.full((), 1e-8, dtype=torch.float32, device=x_f32.device)
    else:
        x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)

    # Per-tensor global scale (same role as weight_scale_2).
    input_global_scale = x_amax / (6.0 * 448.0)

    # Normalize by the global scale before block quantization.
    x_normalized = x_f32 / input_global_scale
    x_fp4, x_sf = _quantize_to_e2m1(x_normalized)
    return x_fp4, x_sf, input_global_scale
def nvfp4_mega_moe_full(
    y,                          # output tensor (num_tokens, HIDDEN) bfloat16, written in place
    transformed_l1_weights,     # (l1_w, l1_sf) tuple from finalize_weights
    transformed_l2_weights,     # (l2_w, l2_sf) tuple from finalize_weights
    symm_buffer,                # SymmBuffer from get_symm_buffer
    activation_clamp=None,      # optional upper bound applied to the SiLU(gate) * up product
    fast_math=False,            # accepted for API compatibility; not used here
):
    """Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe.

    The signature mirrors the DeepGEMM call used by the vLLM deepseek_v4.py
    patch:
        fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer,
                           activation_clamp=..., fast_math=...)

    Pipeline:
        1. Read the staged (already quantized) activation, its scales, the
           global scale and the routing tensors from ``symm_buffer``.
        2. L1 GEMM (gate_up_proj) through ``nvfp4_mega_moe_l1``.
        3. SiLU(gate) * up, optionally clamped from above.
        4. Re-quantize to FP4 with ``stage_activation``.
        5. L2 GEMM (down_proj) through ``nvfp4_mega_moe_l2``.
        6. Copy the result into ``y``. Per the original design note, the
           cross-rank reduction is the caller's responsibility.

    When ``MEGA_MOE_STATIC`` is set the pipeline is skipped and ``y`` is
    zero-filled. Diagnostics are printed only when ``MEGA_MOE_DEBUG`` is
    set: each one pulls scalars to the host (``.item()``), which forces a
    device synchronization and is undefined for zero tokens.

    Args:
        y: output tensor; its first dimension defines ``num_tokens``.
        transformed_l1_weights: (weights, scales) pair for the L1 GEMM.
        transformed_l2_weights: (weights, scales) pair for the L2 GEMM.
        symm_buffer: object exposing ``x``, ``x_sf``, ``input_global_scale``,
            ``topk_idx``, ``topk_weights`` and ``experts_start_idx``.
        activation_clamp: if not None, ``clamp(max=activation_clamp)`` is
            applied to the activated tensor.
        fast_math: ignored.

    Returns:
        None. The result is written into ``y``.
    """
    num_tokens = y.shape[0]

    if MEGA_MOE_STATIC:
        if MEGA_MOE_DEBUG:
            print(f"[MEGA_MOE_STATIC] Skipping nvfp4_mega_moe, returning zeros "
                  f"shape=({num_tokens}, {y.shape[1]})")
        y.zero_()
        return

    # The min()/max() reductions used by the diagnostics raise on empty
    # tensors, so they are skipped when this call carries no tokens.
    debug = bool(MEGA_MOE_DEBUG) and num_tokens > 0

    # Unpack transformed weights
    l1_w, l1_sf = transformed_l1_weights
    l2_w, l2_sf = transformed_l2_weights

    # Step 1: Read staged activation from symm_buffer
    x_fp4 = symm_buffer.x[:num_tokens]
    x_sf = symm_buffer.x_sf[:num_tokens]
    l1_global_scale = symm_buffer.input_global_scale  # fp32, from stage_activation
    topk_ids = symm_buffer.topk_idx[:num_tokens]
    topk_weights = symm_buffer.topk_weights[:num_tokens]

    if debug:
        _x_sf_f32 = x_sf.to(torch.float32)
        _igs = float(l1_global_scale)
        print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}] x_fp4_absmax={x_fp4.view(torch.int8).abs().max().item()}")

    # Convert global expert IDs to rank-local IDs: the weight tensors are
    # indexed from 0 on every rank, and this rank's first global expert ID
    # is symm_buffer.experts_start_idx.
    # NOTE(review): IDs of experts owned by other ranks end up outside
    # [0, num_experts_per_rank); presumably the grouped GEMM skips them —
    # confirm against cutlass_grouped_nvfp4_gemm.
    num_experts_per_rank = l1_w.shape[0]
    experts_start_idx = symm_buffer.experts_start_idx
    topk_ids_local = topk_ids - experts_start_idx

    if debug:
        print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
              f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
              f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
              f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
        # NaN-trace: check activation scales at L1 input
        x_sf_f32 = x_sf.to(torch.float32)
        print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} "
              f"inf={torch.isinf(x_sf_f32).any().item()} "
              f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}")

    # Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
    l1_output = nvfp4_mega_moe_l1(
        x_fp4, x_sf, l1_w, l1_sf,
        topk_ids_local, topk_weights, num_experts_per_rank,
        alpha=l1_global_scale,
    )

    if debug:
        # NaN-trace: check L1 output
        print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} "
              f"inf={torch.isinf(l1_output).any().item()} "
              f"abs_max={l1_output.abs().max().item():.4e}")

    # Step 3: SiLU + Mul — the first half of the last dim is the gate, the
    # second half the up projection.
    gate, up = l1_output.chunk(2, dim=-1)
    activated = torch.nn.functional.silu(gate) * up
    if activation_clamp is not None:
        activated = activated.clamp(max=activation_clamp)

    if debug:
        # NaN-trace: check SiLU output
        print(f"[silu] nan={torch.isnan(activated).any().item()} "
              f"abs_max={activated.abs().max().item():.4e}")

    # Step 4: Quantize L1 output → FP4
    l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated)

    if debug:
        _l1sf_f32 = l1_sf_out.to(torch.float32)
        _l2gs = float(l2_global_scale)
        print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}] activated amax={activated.abs().max().item():.4e}")

    # Step 5: L2 GEMM (native NVFP4 block-scaled MMA)
    l2_output = nvfp4_mega_moe_l2(
        l1_fp4, l1_sf_out, l2_w, l2_sf,
        topk_ids_local, topk_weights, num_experts_per_rank,
        alpha=l2_global_scale,
    )

    if debug:
        # NaN-trace: check L2 output
        print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} "
              f"abs_max={l2_output.abs().max().item():.4e}")

    # Step 6: Write to output (caller handles cross-rank all-reduce)
    y.copy_(l2_output)