383 lines
15 KiB
Python
383 lines
15 KiB
Python
"""
NVFP4 Mega MoE Kernel — full MoE forward pass with expert parallelism.

This is the main entry point that replaces fp8_nvfp4_mega_moe from DeepGEMM.

Architecture:
- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- SiLU+Mul activation
- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- NVLink cross-rank sync is handled by the caller (not by this kernel)
- Expert parallel: each rank handles NUM_EXPERTS / NUM_RANKS experts

The GEMMs use native NVFP4 block-scaled MMA (tcgen05.mma with block scaling)
on Blackwell (SM100) through a CUTLASS kernel.
NOTE(review): per the PTX ISA, E2M1 operands with UE4M3 block-16 scales
correspond to `.kind::mxf4nvf4.block_scale` with `scale_vec::4X`;
`.kind::mxf8f6f4` uses UE8M0 block-32 scales. Confirm which instruction kind
the CUTLASS kernel actually emits.

Native NVFP4 path:
    E2M1 (int8, 2 values/byte) × E2M1 + UE4M3 block-16 scales
    → native hardware block-scaled MMA in tensor cores
    → float32 accumulator

This replaces the dequantize-then-BF16-GEMM approach: the native path performs
the E2M1 × E2M1 product with UE4M3 block scaling entirely in hardware,
avoiding the costly dequantization step.
"""
|
||
|
||
import os
|
||
import torch
|
||
|
||
def unpack_ue4m3_u32(x_u32):
|
||
"""Unpack uint32 packed UE4M3 scales to float8_e4m3fn.
|
||
|
||
Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32].
|
||
Must use bit reinterpret (view), NOT value cast (to) — byte 0x3F is the float8
|
||
whose bits are 0x3F (~0.984), NOT the integer 63.
|
||
|
||
CUDA doesn't implement bitwise ops on uint32, so we cast to int32 first.
|
||
"""
|
||
# CUDA uint32 lacks bitwise ops — use int32
|
||
x_i32 = x_u32.to(torch.int32)
|
||
M, N = x_i32.shape
|
||
|
||
# Extract 4 bytes, cast to uint8, then bit-reinterpret to float8_e4m3fn
|
||
b0 = (x_i32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||
b1 = ((x_i32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||
b2 = ((x_i32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||
b3 = ((x_i32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||
|
||
# Interleave into (M, N*4)
|
||
out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
|
||
out[:, 0::4] = b0
|
||
out[:, 1::4] = b1
|
||
out[:, 2::4] = b2
|
||
out[:, 3::4] = b3
|
||
return out
|
||
|
||
# CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell).
# The kernels live in an optional extension module. When it cannot be
# imported, _CUTLASS_AVAILABLE is False and the two kernel names below stay
# undefined in this module.
# NOTE(review): MEGA_MOE_USE_CUTLASS (env var, default 1) appears to be unused;
# the GEMM entry points always call the CUTLASS kernel.
MEGA_MOE_USE_CUTLASS = int(os.environ.get("MEGA_MOE_USE_CUTLASS", "1"))

try:
    from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import (
        cutlass_grouped_nvfp4_gemm,
        cutlass_nvfp4_blockscaled_gemm,
    )
except ImportError:
    _CUTLASS_AVAILABLE = False
else:
    _CUTLASS_AVAILABLE = True
|
||
|
||
# ---------------------------------------------------------------------------
# Model geometry (DeepSeek-V4-Pro)
# ---------------------------------------------------------------------------
HIDDEN: int = 7168  # model hidden size (K of the L1 GEMM, N of the L2 GEMM)
INTERMEDIATE: int = 3072  # expert intermediate size; L1 emits 2x (gate + up)
NUM_EXPERTS: int = 256  # routed experts across all ranks
NUM_RANKS: int = 8  # expert-parallel ranks
NUM_TOPK: int = 6  # routed experts per token
# NOTE(review): nvfp4_mega_moe_full's comments describe global expert ids
# 0..383 and local ids 0..47 per rank, which disagrees with
# NUM_EXPERTS / NUM_RANKS == 32. The runtime derives the per-rank count from
# the weight tensor, not from these constants; confirm which figure is right.

# NVFP4 scale-factor layout
SF_GRANULARITY_K: int = 16  # elements covered by one UE4M3 scale (group size)
SF_PACK_FACTOR: int = 4  # UE4M3 scales packed into one uint32 word

# Runtime switches, read once at import time from the environment
MEGA_MOE_STATIC: int = int(os.environ.get("MEGA_MOE_STATIC", "0"))  # nonzero: skip the MoE, zero y
MEGA_MOE_DEBUG: int = int(os.environ.get("MEGA_MOE_DEBUG", "0"))  # nonzero: print diagnostics
|
||
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main kernel entry points
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def nvfp4_mega_moe_l1(
    x_fp4,  # (num_tokens, K//2) int8 packed E2M1
    x_sf,  # (num_tokens, sf_k_groups) uint32 packed UE4M3
    l1_weights,  # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
    l1_scales,  # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major
    topk_ids,  # (num_tokens, NUM_TOPK) int32
    topk_weights,  # (num_tokens, NUM_TOPK) float32
    num_experts_per_rank,
    alpha=1.0,  # fp32 scalar from stage_activation global scale
):
    """L1 GEMM: gate_up_proj on the native NVFP4 block-scaled MMA path.

    Scale tensors stored as torch.uint32 (four UE4M3 bytes per word) are
    unpacked to float8_e4m3fn with ``unpack_ue4m3_u32``. Scales of any other
    dtype are passed through unchanged.

    The grouped GEMM itself is delegated to ``cutlass_grouped_nvfp4_gemm``,
    and ``alpha`` is forwarded as the per-tensor scale.

    There is no dequantize+BF16 fallback. If the CUTLASS extension failed to
    import, a RuntimeError is raised up front instead of a NameError at the
    call site.

    ``num_experts_per_rank`` is only reported in debug output here.

    Returns:
        The output of ``cutlass_grouped_nvfp4_gemm``, documented as a
        (num_tokens, 2 * INTERMEDIATE) bfloat16 tensor.

    Raises:
        RuntimeError: if the CUTLASS NVFP4 kernels are unavailable.
    """
    if not _CUTLASS_AVAILABLE:
        raise RuntimeError(
            "nvfp4_mega_moe_l1 requires nvfp4_megamoe_kernel.cutlass_nvfp4_gemm, "
            "which failed to import; no dequantize+BF16 fallback is implemented."
        )

    num_tokens = x_fp4.shape[0]
    K = x_fp4.shape[1] * 2  # HIDDEN = 7168 (two E2M1 values per byte)
    N = l1_weights.shape[2]  # 2 * INTERMEDIATE = 6144 (column-major: shape is E, K_half, N)

    if MEGA_MOE_DEBUG:
        print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
              f"experts={num_experts_per_rank} native=1")

    # torch.uint32 only exists in newer PyTorch; without it nothing is packed.
    uint32 = getattr(torch, "uint32", None)
    x_sf_fp8 = unpack_ue4m3_u32(x_sf) if uint32 is not None and x_sf.dtype == uint32 else x_sf
    w_sf_fp8 = (
        unpack_ue4m3_u32(l1_scales)
        if uint32 is not None and l1_scales.dtype == uint32
        else l1_scales
    )

    output = cutlass_grouped_nvfp4_gemm(
        x_fp4, x_sf_fp8,
        l1_weights, w_sf_fp8,
        topk_ids, topk_weights,
        alpha=alpha,
    )
    return output  # (num_tokens, 6144) bfloat16
|
||
|
||
|
||
def nvfp4_mega_moe_l2(
    x_fp4,  # (num_tokens, INTER//2) int8 packed E2M1
    x_sf,  # (num_tokens, sf_k_groups) uint32 packed UE4M3
    l2_weights,  # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
    l2_scales,  # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
    topk_ids,  # (num_tokens, NUM_TOPK) int32
    topk_weights,  # (num_tokens, NUM_TOPK) float32
    num_experts_per_rank,
    alpha=1.0,  # fp32 scalar from stage_activation global scale
):
    """L2 GEMM: down_proj on the native NVFP4 block-scaled MMA path.

    This is the same pipeline as the L1 GEMM:
    - torch.uint32-packed UE4M3 scales are unpacked to float8_e4m3fn with
      ``unpack_ue4m3_u32``; other scale dtypes pass through unchanged.
    - ``cutlass_grouped_nvfp4_gemm`` is called with ``alpha`` as the
      per-tensor scale.

    There is no dequantize+BF16 fallback. If the CUTLASS extension failed to
    import, a RuntimeError is raised up front.

    ``num_experts_per_rank`` is only reported in debug output here.

    Returns:
        The output of ``cutlass_grouped_nvfp4_gemm``, documented as a
        (num_tokens, HIDDEN) bfloat16 tensor.

    Raises:
        RuntimeError: if the CUTLASS NVFP4 kernels are unavailable.
    """
    if not _CUTLASS_AVAILABLE:
        raise RuntimeError(
            "nvfp4_mega_moe_l2 requires nvfp4_megamoe_kernel.cutlass_nvfp4_gemm, "
            "which failed to import; no dequantize+BF16 fallback is implemented."
        )

    num_tokens = x_fp4.shape[0]
    K = x_fp4.shape[1] * 2  # INTERMEDIATE = 3072 (two E2M1 values per byte)
    N = l2_weights.shape[2]  # HIDDEN = 7168 (column-major: shape is E, K_half, N)

    if MEGA_MOE_DEBUG:
        print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} "
              f"experts={num_experts_per_rank} native=1")

    # torch.uint32 only exists in newer PyTorch; without it nothing is packed.
    uint32 = getattr(torch, "uint32", None)
    x_sf_fp8 = unpack_ue4m3_u32(x_sf) if uint32 is not None and x_sf.dtype == uint32 else x_sf
    w_sf_fp8 = (
        unpack_ue4m3_u32(l2_scales)
        if uint32 is not None and l2_scales.dtype == uint32
        else l2_scales
    )

    output = cutlass_grouped_nvfp4_gemm(
        x_fp4, x_sf_fp8,
        l2_weights, w_sf_fp8,
        topk_ids, topk_weights,
        alpha=alpha,
    )
    return output  # (num_tokens, 7168) bfloat16
|
||
|
||
|
||
# E2M1 (FP4) representable magnitudes: {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
||
# Bit patterns (3-bit, no sign): 000=0, 001=0.5, 010=1, 011=1.5, 100=2, 101=3, 110=4, 111=6
|
||
# Full 4-bit nibble: bit 3 = sign, bits 2:0 = magnitude index
|
||
_E2M1_MAGNITUDES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.float32)
|
||
|
||
|
||
def _quantize_to_e2m1(x_f32):
|
||
"""Quantize float32 values to E2M1 (FP4) nibble indices.
|
||
|
||
Maps each value to the nearest E2M1 representable magnitude,
|
||
then packs as 4-bit sign-magnitude nibbles.
|
||
|
||
Returns (nibbles, scales) where:
|
||
nibbles: (..., N) uint8 with 4-bit sign-magnitude per value
|
||
scales: (..., N//16) float8_e4m3fn block scales
|
||
"""
|
||
*batch, N = x_f32.shape
|
||
assert N % 16 == 0, f"Last dim {N} not divisible by 16 (block size)"
|
||
|
||
# Reshape into blocks of 16 for block-wise scaling
|
||
x_blocks = x_f32.reshape(*batch, N // 16, 16)
|
||
|
||
# Per-block absmax determines the scale
|
||
block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8, max=448.0)
|
||
|
||
# Scale so that the max maps to 6.0 (largest E2M1 magnitude)
|
||
# Dequant: x_reconstructed = x_e2m1 * scale, where scale = block_max / 6.0
|
||
scale_f32 = block_max / 6.0
|
||
x_scaled = x_blocks / scale_f32.clamp(min=1e-8)
|
||
|
||
# Find nearest E2M1 magnitude for each value
|
||
signs = torch.sign(x_scaled) # +1, -1, or 0
|
||
abs_scaled = x_scaled.abs() # 0..6 range
|
||
|
||
# Nearest E2M1 magnitude: find closest in {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
||
mags = _E2M1_MAGNITUDES.to(device=abs_scaled.device)
|
||
# Distance from each value to each magnitude
|
||
dists = (abs_scaled.unsqueeze(-1) - mags).abs() # (..., 16, 8)
|
||
idx = dists.argmin(dim=-1) # (..., 16) — index into E2M1 magnitudes
|
||
|
||
# Clamp to valid range (safety)
|
||
idx = idx.clamp(0, 7).to(torch.uint8)
|
||
|
||
# Build 4-bit sign-magnitude nibble: bit3=sign, bits2:0=magnitude index
|
||
sign_bit = (signs < 0).to(torch.uint8) # 1 if negative
|
||
nibbles = (sign_bit << 3) | idx # (..., 16) uint8, values 0..15
|
||
|
||
# Pack 2 nibbles per byte: low nibble = even index, high nibble = odd index
|
||
nibbles = nibbles.reshape(*batch, N // 2, 2)
|
||
packed = (nibbles[..., 1] << 4) | nibbles[..., 0] # (..., N//2) uint8
|
||
|
||
# Scale factors: what the GEMM needs to reconstruct the original values
|
||
# dequant = e2m1_magnitude * scale, so scale = block_max / 6.0
|
||
sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn) # (..., N//16)
|
||
|
||
return packed.to(torch.int8), sf
|
||
|
||
|
||
def stage_activation(x_bf16):
    """Quantize an activation to NVFP4: packed E2M1 values plus UE4M3 block-16 scales.

    Two-level scaling, matching the NVFP4 weight format:

    1. Per-tensor global scale ``amax / (6.0 * 448.0)``, with amax floored at
       1e-8. Dividing by it maps the tensor's largest magnitude to 6 * 448,
       so per-block scales (block_max / 6) fit in the UE4M3 range (max 448).
    2. ``_quantize_to_e2m1`` on the normalized values: per-block (16 values)
       float8_e4m3fn scales and 4-bit sign-magnitude E2M1 nibbles
       (bit3 = sign, bits2:0 = magnitude index) packed two per byte.

    Returns:
        (x_fp4, x_sf, input_global_scale):
            x_fp4: packed E2M1 nibbles.
            x_sf: UE4M3 block scales, NOT folded with the global scale.
            input_global_scale: 0-dim float32 tensor, applied as the GEMM
                alpha: D = alpha * (A_sf * A_fp4) @ (B_sf * B_fp4).
                Keeping it separate avoids an fp32 -> UE4M3 round-trip.

    An empty input (e.g. a zero-token batch) yields empty outputs and the
    floor global scale, instead of failing inside ``amax()``.
    """
    x_f32 = x_bf16.float()

    # Per-tensor global scale (same role as weight_scale_2): amax / (6 * 448).
    # amax() has no identity element, so an empty tensor uses the same floor
    # as an all-zero one.
    if x_f32.numel() == 0:
        x_amax = torch.tensor(1e-8, dtype=torch.float32, device=x_f32.device)
    else:
        x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)
    input_global_scale = x_amax / (6.0 * 448.0)

    # Normalize by the global scale before block quantization, so that
    # block_max / 6.0 <= 448.0 and block scales fit in UE4M3.
    x_normalized = x_f32 / input_global_scale

    x_fp4, x_sf = _quantize_to_e2m1(x_normalized)

    return x_fp4, x_sf, input_global_scale
|
||
|
||
|
||
def _scalar_to_float(value):
    """Return ``value`` as a Python float, for debug logging.

    Accepts a Python float, anything with ``.item()`` (e.g. a one-element
    tensor; on a CUDA tensor this synchronizes the device), or anything
    ``float()`` accepts.
    """
    if isinstance(value, float):
        return value
    if hasattr(value, "item"):
        return value.item()
    return float(value)


def nvfp4_mega_moe_full(
    y,  # output tensor (num_tokens, HIDDEN) bfloat16; overwritten in place
    transformed_l1_weights,  # (l1_w, l1_sf) tuple from finalize_weights
    transformed_l2_weights,  # (l2_w, l2_sf) tuple from finalize_weights
    symm_buffer,  # SymmBuffer from get_symm_buffer
    activation_clamp=None,  # optional upper bound applied to silu(gate) * up
    fast_math=False,  # fast math flag (accepted for API compatibility; unused)
):
    """Full mega_moe forward pass; replaces deep_gemm.mega.fp8_nvfp4_mega_moe.

    The API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in
    the vLLM deepseek_v4.py patch:

        fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer,
                           activation_clamp=..., fast_math=...)

    Pipeline:
        1. Read the staged activation, its block scales, the global scale
           and the routing (topk ids/weights) from ``symm_buffer``.
        2. L1 GEMM: gate_up_proj (native NVFP4 block-scaled MMA).
        3. SiLU(gate) * up, optionally clamped from above by
           ``activation_clamp``.
        4. Re-quantize the activation to FP4 + UE4M3 scales
           (``stage_activation``).
        5. L2 GEMM: down_proj (native NVFP4 block-scaled MMA).
        6. Copy into ``y``; the caller handles the cross-rank all-reduce.

    Special cases:
    - With MEGA_MOE_STATIC set, ``y`` is zeroed and nothing is computed.
    - An empty batch (num_tokens == 0) returns immediately.

    Diagnostics (shapes, alphas, scale ranges, NaN/Inf checks) are printed
    only when MEGA_MOE_DEBUG is set. Each statistic calls ``.item()``, which
    forces a host/device synchronization.

    Returns:
        None; the result is written into ``y``.
    """
    num_tokens = y.shape[0]

    if MEGA_MOE_STATIC:
        if MEGA_MOE_DEBUG:
            print(f"[MEGA_MOE_STATIC] Skipping nvfp4_mega_moe, returning zeros "
                  f"shape=({num_tokens}, {y.shape[1]})")
        y.zero_()
        return

    # Nothing to compute for an empty batch. The quantizer's amax and the
    # debug min/max reductions would raise on empty tensors.
    if num_tokens == 0:
        return

    # Unpack transformed weights
    l1_w, l1_sf = transformed_l1_weights
    l2_w, l2_sf = transformed_l2_weights

    # Step 1: Read staged activation from symm_buffer
    x_fp4 = symm_buffer.x[:num_tokens]
    x_sf = symm_buffer.x_sf[:num_tokens]
    l1_global_scale = symm_buffer.input_global_scale  # fp32, from stage_activation
    topk_ids = symm_buffer.topk_idx[:num_tokens]
    topk_weights = symm_buffer.topk_weights[:num_tokens]

    # Convert global expert IDs to local expert IDs.
    # vLLM's symm_buffer stores global IDs (0..383) but our weight tensors
    # are indexed by local ID (0..47). Each rank handles a contiguous chunk:
    # rank r gets experts [r*E_per_rank, (r+1)*E_per_rank).
    # NOTE(review): those ranges disagree with NUM_EXPERTS / NUM_RANKS; the
    # per-rank count used here comes from the weight tensor.
    # NOTE(review): ids owned by other ranks (and any padding ids) become
    # negative or >= num_experts_per_rank here; confirm that
    # cutlass_grouped_nvfp4_gemm skips such entries.
    num_experts_per_rank = l1_w.shape[0]
    topk_ids_local = topk_ids - symm_buffer.experts_start_idx

    if MEGA_MOE_DEBUG:
        x_sf_f32 = x_sf.to(torch.float32)
        _igs = _scalar_to_float(l1_global_scale)
        print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{x_sf_f32.min().item():.4e}, {x_sf_f32.max().item():.4e}] x_fp4_absmax={x_fp4.view(torch.int8).abs().max().item()}")
        print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
              f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
              f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
              f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
        # NaN-trace: activation scales at the L1 input
        print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} "
              f"inf={torch.isinf(x_sf_f32).any().item()} "
              f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}")

    # Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
    # NOTE(review): topk_weights is passed to both the L1 and the L2 grouped
    # GEMM. If the kernel applies routing weights, they would be applied
    # twice, the first time before the SiLU.
    # NOTE(review): the L1 output is documented as one row per token
    # (num_tokens, 2*INTER), not one row per (token, expert) pair. A
    # standard MoE applies the activation per expert and combines experts
    # only after L2. Confirm how cutlass_grouped_nvfp4_gemm reduces over topk.
    l1_output = nvfp4_mega_moe_l1(
        x_fp4, x_sf, l1_w, l1_sf,
        topk_ids_local, topk_weights, num_experts_per_rank,
        alpha=l1_global_scale,
    )

    if MEGA_MOE_DEBUG:
        print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} "
              f"inf={torch.isinf(l1_output).any().item()} "
              f"abs_max={l1_output.abs().max().item():.4e}")

    # Step 3: SiLU + Mul
    gate, up = l1_output.chunk(2, dim=-1)
    activated = torch.nn.functional.silu(gate) * up
    if activation_clamp is not None:
        # NOTE(review): only an upper bound on the product is applied. Some
        # SwiGLU-with-limit references clamp gate/up before the product;
        # confirm the DeepGEMM semantics of activation_clamp.
        activated = activated.clamp(max=activation_clamp)

    if MEGA_MOE_DEBUG:
        print(f"[silu] nan={torch.isnan(activated).any().item()} "
              f"abs_max={activated.abs().max().item():.4e}")

    # Step 4: Quantize L1 output -> FP4 (+ per-tensor scale used as L2 alpha)
    l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated)

    if MEGA_MOE_DEBUG:
        l1_sf_f32 = l1_sf_out.to(torch.float32)
        _l2gs = _scalar_to_float(l2_global_scale)
        print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{l1_sf_f32.min().item():.4e}, {l1_sf_f32.max().item():.4e}] activated amax={activated.abs().max().item():.4e}")

    # Step 5: L2 GEMM (native NVFP4 block-scaled MMA)
    l2_output = nvfp4_mega_moe_l2(
        l1_fp4, l1_sf_out, l2_w, l2_sf,
        topk_ids_local, topk_weights, num_experts_per_rank,
        alpha=l2_global_scale,
    )

    if MEGA_MOE_DEBUG:
        print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} "
              f"abs_max={l2_output.abs().max().item():.4e}")

    # Step 6: Write to output (caller handles cross-rank all-reduce)
    y.copy_(l2_output)
|