fix the L2 path and the clamping math
This commit is contained in:
@@ -36,46 +36,58 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales
|
||||
weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS
|
||||
weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major for CUTLASS
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32
|
||||
topk_weights, # (num_tokens, NUM_TOPK) float32
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
|
||||
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
|
||||
):
|
||||
"""Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.
|
||||
|
||||
For each expert, gather the tokens routed to it, run the block-scaled GEMM,
|
||||
then scatter results back with routing weights.
|
||||
|
||||
Returns slot-based output: one row per (token, topk) slot routed to a local
|
||||
expert. No routing weights applied — caller handles that at the final scatter.
|
||||
|
||||
Returns:
|
||||
slot_out: (num_slots, N) bfloat16 — per-slot GEMM results
|
||||
slot_token: (num_slots,) int64 — token index for each slot
|
||||
"""
|
||||
num_tokens = x_fp4.shape[0]
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2 # Actual K dimension (2 FP4 per byte)
|
||||
# Weights are (E, K_half, N) column-major (transposed at load time for CUTLASS ColumnMajor B)
|
||||
N = weights.shape[2] # Output dimension
|
||||
K = K_half * 2
|
||||
N = weights.shape[2]
|
||||
num_experts = weights.shape[0]
|
||||
num_topk = topk_ids.shape[1]
|
||||
|
||||
|
||||
# Build slot mapping: which (token, topk) pairs land on local experts?
|
||||
local_mask = (topk_ids >= 0) & (topk_ids < num_experts) # (num_tokens, num_topk)
|
||||
slot_token, slot_k = local_mask.nonzero(as_tuple=True) # (num_slots,)
|
||||
slot_expert = topk_ids[slot_token, slot_k] # (num_slots,) local expert id
|
||||
|
||||
num_slots = slot_token.shape[0]
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} "
|
||||
f"experts={num_experts} topk={num_topk}")
|
||||
|
||||
output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
|
||||
f"experts={num_experts} topk={num_topk} slots={num_slots}")
|
||||
|
||||
if num_slots == 0:
|
||||
slot_out = torch.empty(0, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
return slot_out, slot_token
|
||||
|
||||
# Gather activations for all slots
|
||||
slot_x = x_fp4[slot_token] # (num_slots, K_half)
|
||||
slot_x_sf = x_sf[slot_token] # (num_slots, sf_k)
|
||||
|
||||
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
|
||||
for e in range(num_experts):
|
||||
# Find tokens routed to this expert
|
||||
expert_mask = (topk_ids == e) # (num_tokens, num_topk)
|
||||
token_indices = expert_mask.any(dim=1).nonzero(as_tuple=True)[0]
|
||||
|
||||
if token_indices.numel() == 0:
|
||||
expert_slots = (slot_expert == e)
|
||||
if not expert_slots.any():
|
||||
continue
|
||||
|
||||
# Gather tokens for this expert
|
||||
expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half)
|
||||
expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k)
|
||||
expert_w = weights[e] # (K_half, N) column-major for CUTLASS
|
||||
expert_w_sf = weight_sf[e] # (sf_k, N) column-major for CUTLASS
|
||||
|
||||
M_expert = token_indices.shape[0]
|
||||
|
||||
# DEBUG: verify data going into GEMM
|
||||
|
||||
e_idx = expert_slots.nonzero(as_tuple=True)[0]
|
||||
expert_x = slot_x[e_idx]
|
||||
expert_x_sf = slot_x_sf[e_idx]
|
||||
expert_w = weights[e]
|
||||
expert_w_sf = weight_sf[e]
|
||||
M_expert = e_idx.shape[0]
|
||||
|
||||
if e < 3 and M_expert > 0:
|
||||
print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} "
|
||||
f"w shape={expert_w.shape} w_sf shape={expert_w_sf.shape} "
|
||||
@@ -83,19 +95,16 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
f"w_sf range=[{expert_w_sf.to(torch.float32).min().item():.4e}, "
|
||||
f"{expert_w_sf.to(torch.float32).max().item():.4e}] "
|
||||
f"w_sf nonzero_frac={(expert_w_sf.view(torch.uint8) != 0).float().mean().item():.4f}")
|
||||
|
||||
# Run CUTLASS NVFP4 block-scaled GEMM
|
||||
|
||||
expert_out = cutlass_nvfp4_blockscaled_gemm(
|
||||
expert_x, expert_x_sf,
|
||||
expert_w, expert_w_sf, # Pass directly — already (N, K_half) and (N, sf_k)
|
||||
expert_w, expert_w_sf,
|
||||
M_expert, N, K,
|
||||
alpha=alpha,
|
||||
) # (M_expert, N) bfloat16
|
||||
|
||||
# Check for CUDA errors after each expert GEMM
|
||||
)
|
||||
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# Hard-fail on NaN/Inf — silent skip was hiding bugs
|
||||
|
||||
if torch.isnan(expert_out).any() or torch.isinf(expert_out).any():
|
||||
raise RuntimeError(
|
||||
f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. "
|
||||
@@ -108,11 +117,7 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
f"x_sf nan_frac={torch.isnan(expert_x_sf.to(torch.float32)).float().mean().item():.4f}, "
|
||||
f"w_sf nan_frac={torch.isnan(expert_w_sf.to(torch.float32)).float().mean().item():.4f}"
|
||||
)
|
||||
|
||||
# Scatter back with routing weights
|
||||
for t_idx, token_idx in enumerate(token_indices):
|
||||
for k_idx in range(num_topk):
|
||||
if topk_ids[token_idx, k_idx] == e:
|
||||
output[token_idx] += topk_weights[token_idx, k_idx] * expert_out[t_idx]
|
||||
|
||||
return output
|
||||
|
||||
slot_out[e_idx] = expert_out
|
||||
|
||||
return slot_out, slot_token
|
||||
|
||||
@@ -5,8 +5,9 @@ This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM.
|
||||
|
||||
Architecture:
|
||||
- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales)
|
||||
- SiLU+Mul activation
|
||||
- SiLU+Mul activation (per-slot, BEFORE combining expert paths)
|
||||
- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales)
|
||||
- Routing weights applied ONCE at final scatter
|
||||
- NVLink cross-rank sync handled by caller (not this kernel)
|
||||
- Expert parallel: each rank handles NUM_EXPERTS/8 experts
|
||||
|
||||
@@ -90,82 +91,75 @@ MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
|
||||
|
||||
def nvfp4_mega_moe_l1(
|
||||
x_fp4, # (num_tokens, K//2) int8 packed E2M1
|
||||
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
|
||||
x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn
|
||||
l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
|
||||
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32
|
||||
topk_weights, # (num_tokens, NUM_TOPK) float32
|
||||
num_experts_per_rank,
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
):
|
||||
"""L1 GEMM: gate_up_proj — Native NVFP4 block-scaled MMA.
|
||||
"""L1 GEMM: gate_up_proj — slot-based, no routing weights.
|
||||
|
||||
Uses tcgen05.mma.kind::mxf8f6f4.block_scale for native E2M1×E2M1
|
||||
with UE4M3 block-16 scaling in tensor cores.
|
||||
|
||||
Falls back to dequantize+BF16 if native path unavailable.
|
||||
Returns (slot_out, slot_token) where each slot is one (token, topk) pair.
|
||||
Caller applies SiLU+Mul per-slot, then L2, then final scatter with weights.
|
||||
"""
|
||||
num_tokens = x_fp4.shape[0]
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2 # HIDDEN = 7168
|
||||
N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144 (column-major: shape is E, K_half, N)
|
||||
K = K_half * 2
|
||||
N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
|
||||
f"experts={num_experts_per_rank} native=1")
|
||||
print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} native=1")
|
||||
|
||||
# DEBUG: verify weight shapes after transpose
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[L1-WT] l1_w shape={l1_weights.shape} l1_sf shape={l1_scales.shape} w_sf dtype={l1_scales.dtype}")
|
||||
|
||||
# Unpack uint32 packed UE4M3 scales to float8_e4m3fn
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
|
||||
|
||||
output = cutlass_grouped_nvfp4_gemm(
|
||||
slot_out, slot_token = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l1_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
topk_ids,
|
||||
alpha=alpha,
|
||||
)
|
||||
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
|
||||
return output # (num_tokens, 6144) bfloat16
|
||||
print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e} mean={slot_out.float().mean().item():.4e}")
|
||||
return slot_out, slot_token
|
||||
|
||||
|
||||
def nvfp4_mega_moe_l2(
|
||||
x_fp4, # (num_tokens, INTER//2) int8 packed E2M1
|
||||
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
|
||||
x_fp4, # (num_slots, INTER//2) int8 packed E2M1
|
||||
x_sf, # (num_slots, sf_k_groups) float8_e4m3fn
|
||||
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
|
||||
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32
|
||||
topk_weights, # (num_tokens, NUM_TOPK) float32
|
||||
num_experts_per_rank,
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs (for slot mapping)
|
||||
slot_token, # (num_slots,) int64 — token index per slot (from L1)
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
):
|
||||
"""L2 GEMM: down_proj — Native NVFP4 block-scaled MMA.
|
||||
"""L2 GEMM: down_proj — slot-based, no routing weights.
|
||||
|
||||
Same pipeline as L1 using native mxf8f6f4.block_scale MMA.
|
||||
Reuses the same slot mapping from L1 (same slot_token indices).
|
||||
topk_ids is passed to rebuild the slot→expert mapping.
|
||||
"""
|
||||
num_tokens = x_fp4.shape[0]
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2 # INTERMEDIATE = 3072
|
||||
N = l2_weights.shape[2] # HIDDEN = 7168 (column-major: shape is E, K_half, N)
|
||||
K = K_half * 2
|
||||
N = l2_weights.shape[2]
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} "
|
||||
f"experts={num_experts_per_rank} native=1")
|
||||
print(f"[nvfp4_moe_l2] slots={x_fp4.shape[0]} K={K} N={N} native=1")
|
||||
|
||||
# Unpack uint32 packed UE4M3 scales to float8_e4m3fn
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
|
||||
|
||||
output = cutlass_grouped_nvfp4_gemm(
|
||||
# Build local expert IDs per slot (same mapping as L1)
|
||||
num_topk = topk_ids.shape[1]
|
||||
num_experts = l2_weights.shape[0]
|
||||
local_mask = (topk_ids >= 0) & (topk_ids < num_experts)
|
||||
_, slot_k = local_mask.nonzero(as_tuple=True)
|
||||
slot_expert_ids = topk_ids[slot_token, slot_k] # (num_slots,)
|
||||
|
||||
slot_out, _ = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l2_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
slot_expert_ids, # per-slot expert IDs
|
||||
alpha=alpha,
|
||||
)
|
||||
return output # (num_tokens, 7168) bfloat16
|
||||
return slot_out # (num_slots, HIDDEN) bfloat16
|
||||
|
||||
|
||||
# E2M1 (FP4) representable magnitudes: {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
||||
@@ -191,37 +185,29 @@ def _quantize_to_e2m1(x_f32):
|
||||
x_blocks = x_f32.reshape(*batch, N // 16, 16)
|
||||
|
||||
# Per-block absmax determines the scale
|
||||
block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8, max=448.0)
|
||||
block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
|
||||
# Scale so that the max maps to 6.0 (largest E2M1 magnitude)
|
||||
# Dequant: x_reconstructed = x_e2m1 * scale, where scale = block_max / 6.0
|
||||
scale_f32 = block_max / 6.0
|
||||
scale_f32 = (block_max / 6.0).clamp(min=1e-8, max=448.0)
|
||||
x_scaled = x_blocks / scale_f32.clamp(min=1e-8)
|
||||
|
||||
# Find nearest E2M1 magnitude for each value
|
||||
signs = torch.sign(x_scaled) # +1, -1, or 0
|
||||
abs_scaled = x_scaled.abs() # 0..6 range
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs()
|
||||
|
||||
# Nearest E2M1 magnitude: find closest in {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
||||
mags = _E2M1_MAGNITUDES.to(device=abs_scaled.device)
|
||||
# Distance from each value to each magnitude
|
||||
dists = (abs_scaled.unsqueeze(-1) - mags).abs() # (..., 16, 8)
|
||||
idx = dists.argmin(dim=-1) # (..., 16) — index into E2M1 magnitudes
|
||||
dists = (abs_scaled.unsqueeze(-1) - mags).abs()
|
||||
idx = dists.argmin(dim=-1)
|
||||
|
||||
# Clamp to valid range (safety)
|
||||
idx = idx.clamp(0, 7).to(torch.uint8)
|
||||
|
||||
# Build 4-bit sign-magnitude nibble: bit3=sign, bits2:0=magnitude index
|
||||
sign_bit = (signs < 0).to(torch.uint8) # 1 if negative
|
||||
nibbles = (sign_bit << 3) | idx # (..., 16) uint8, values 0..15
|
||||
sign_bit = (signs < 0).to(torch.uint8)
|
||||
nibbles = (sign_bit << 3) | idx
|
||||
|
||||
# Pack 2 nibbles per byte: low nibble = even index, high nibble = odd index
|
||||
nibbles = nibbles.reshape(*batch, N // 2, 2)
|
||||
packed = (nibbles[..., 1] << 4) | nibbles[..., 0] # (..., N//2) uint8
|
||||
packed = (nibbles[..., 1] << 4) | nibbles[..., 0]
|
||||
|
||||
# Scale factors: what the GEMM needs to reconstruct the original values
|
||||
# dequant = e2m1_magnitude * scale, so scale = block_max / 6.0
|
||||
sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn) # (..., N//16)
|
||||
sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn)
|
||||
|
||||
return packed.to(torch.int8), sf
|
||||
|
||||
@@ -231,32 +217,18 @@ def stage_activation(x_bf16):
|
||||
|
||||
Two-level quantization matching the NVFP4 weight format:
|
||||
1. Per-tensor global scale: amax / (6.0 * 448.0)
|
||||
Normalizes the activation so that block scales fit in UE4M3 range.
|
||||
2. Per-block (16 values) absmax scaling on the normalized values
|
||||
Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}
|
||||
Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index)
|
||||
Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn)
|
||||
|
||||
Returns (x_fp4, x_sf, input_global_scale) where:
|
||||
x_fp4: packed E2M1 nibbles
|
||||
x_sf: UE4M3 block scales (NOT folded with global scale)
|
||||
input_global_scale: fp32 per-tensor scale, applied as GEMM alpha
|
||||
|
||||
The GEMM applies global scale via alpha: D = alpha * (A_sf * A_fp4) @ (B_sf * B_fp4)
|
||||
This avoids fp32→UE4M3 round-trip from folding, preserving precision.
|
||||
"""
|
||||
x_f32 = x_bf16.float()
|
||||
|
||||
# Per-tensor global scale (same role as weight_scale_2)
|
||||
# NVFP4 spec: global_scale = amax / (6.0 * 448.0)
|
||||
# This ensures the largest block scale after normalization is ~448.0,
|
||||
# which fits exactly in UE4M3 max (448.0 for E4M3).
|
||||
x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)
|
||||
input_global_scale = x_amax / (6.0 * 448.0)
|
||||
|
||||
# Normalize by global scale before block quantization.
|
||||
# After this, values are in a range where block_max / 6.0 ≤ 448.0,
|
||||
# so block scales fit in UE4M3 without saturation.
|
||||
x_normalized = x_f32 / input_global_scale
|
||||
|
||||
x_fp4, x_sf = _quantize_to_e2m1(x_normalized)
|
||||
@@ -274,22 +246,14 @@ def nvfp4_mega_moe_full(
|
||||
):
|
||||
"""Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe.
|
||||
|
||||
API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in
|
||||
the vLLM deepseek_v4.py patch:
|
||||
|
||||
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer,
|
||||
activation_clamp=..., fast_math=...)
|
||||
|
||||
Pipeline:
|
||||
1. Read staged activation from symm_buffer (already quantized by staging kernel)
|
||||
2. L1 GEMM: gate_up_proj (native NVFP4 block-scaled MMA)
|
||||
3. SiLU + Mul (activation)
|
||||
4. Quantize L1 output → FP4 + UE4M3 scales
|
||||
5. L2 GEMM: down_proj (native NVFP4 block-scaled MMA)
|
||||
6. Write to y (caller handles cross-rank all-reduce)
|
||||
|
||||
Uses tcgen05.mma.kind::mxf8f6f4.block_scale for native E2M1×E2M1
|
||||
with UE4M3 block-16 scaling in Blackwell tensor cores.
|
||||
Slot-based pipeline (routing weights applied ONCE at final scatter):
|
||||
1. Read staged activation from symm_buffer
|
||||
2. L1 GEMM → slot output (num_slots, 2*INTER) — NO routing weights
|
||||
3. SiLU + Mul PER SLOT (nonlinearity before combining expert paths)
|
||||
4. Quantize activated slots → FP4
|
||||
5. L2 GEMM → slot output (num_slots, HIDDEN) — NO routing weights
|
||||
6. Final scatter: y.index_add_(0, slot_token, slot_weight * l2_slots)
|
||||
Single routing weight application.
|
||||
"""
|
||||
num_tokens = y.shape[0]
|
||||
device = y.device
|
||||
@@ -318,87 +282,82 @@ def nvfp4_mega_moe_full(
|
||||
# Step 1: Read staged activation from symm_buffer
|
||||
x_fp4 = symm_buffer.x[:num_tokens]
|
||||
x_sf = symm_buffer.x_sf[:num_tokens]
|
||||
l1_global_scale = symm_buffer.input_global_scale # fp32, from stage_activation
|
||||
l1_global_scale = symm_buffer.input_global_scale
|
||||
topk_ids = symm_buffer.topk_idx[:num_tokens]
|
||||
topk_weights = symm_buffer.topk_weights[:num_tokens]
|
||||
|
||||
# ALWAYS-ON debug: alpha and scale ranges
|
||||
_x_sf_f32 = x_sf.to(torch.float32)
|
||||
_igs = l1_global_scale if isinstance(l1_global_scale, float) else l1_global_scale.item() if hasattr(l1_global_scale, 'item') else float(l1_global_scale)
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}] x_fp4_absmax={x_fp4.view(torch.int8).abs().max().item()}")
|
||||
print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]")
|
||||
|
||||
# Convert global expert IDs to local expert IDs.
|
||||
# vLLM's symm_buffer stores global IDs (0..383) but our weight tensors
|
||||
# are indexed by local ID (0..47). Each rank handles a contiguous chunk:
|
||||
# rank r gets experts [r*E_per_rank, (r+1)*E_per_rank).
|
||||
# Convert global expert IDs to local expert IDs
|
||||
num_experts_per_rank = l1_w.shape[0]
|
||||
experts_start_idx = symm_buffer.experts_start_idx
|
||||
topk_ids_local = topk_ids - experts_start_idx
|
||||
|
||||
# Routing diagnostic (ungated — needed to diagnose zero-GEMM on specific ranks)
|
||||
# Build slot mapping for this rank
|
||||
local_topk = (topk_ids >= experts_start_idx) & (topk_ids < experts_start_idx + num_experts_per_rank)
|
||||
slot_token, slot_k = local_topk.nonzero(as_tuple=True)
|
||||
slot_expert_local = topk_ids_local[slot_token, slot_k]
|
||||
slot_weight = topk_weights[slot_token, slot_k]
|
||||
num_slots = slot_token.shape[0]
|
||||
|
||||
tokens_routed_locally = local_topk.any(dim=-1).sum().item()
|
||||
print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{topk_ids.shape[0]} "
|
||||
f"unique_local_experts={local_topk.long().sum().item()}")
|
||||
print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{num_tokens} "
|
||||
f"num_slots={num_slots}")
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
|
||||
f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
|
||||
f"topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
|
||||
f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
|
||||
f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
|
||||
f"slots={num_slots}")
|
||||
|
||||
# NaN-trace: check activation scales at L1 input
|
||||
if MEGA_MOE_DEBUG:
|
||||
x_sf_f32 = x_sf.to(torch.float32)
|
||||
print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} "
|
||||
f"inf={torch.isinf(x_sf_f32).any().item()} "
|
||||
f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}")
|
||||
# Handle no local slots
|
||||
if num_slots == 0:
|
||||
y.zero_()
|
||||
return
|
||||
|
||||
# Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
|
||||
l1_output = nvfp4_mega_moe_l1(
|
||||
# Step 2: L1 GEMM — slot-based, no routing weights
|
||||
l1_slots, _ = nvfp4_mega_moe_l1(
|
||||
x_fp4, x_sf, l1_w, l1_sf,
|
||||
topk_ids_local, topk_weights, num_experts_per_rank,
|
||||
topk_ids_local,
|
||||
alpha=l1_global_scale,
|
||||
)
|
||||
) # (num_slots, 2*INTER) bfloat16
|
||||
|
||||
# NaN-trace: check L1 output
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} "
|
||||
f"inf={torch.isinf(l1_output).any().item()} "
|
||||
f"abs_max={l1_output.abs().max().item():.4e}")
|
||||
print(f"[L1-out] nan={torch.isnan(l1_slots).any().item()} "
|
||||
f"abs_max={l1_slots.abs().max().item():.4e}")
|
||||
|
||||
# Step 3: SiLU + Mul
|
||||
gate, up = l1_output.chunk(2, dim=-1)
|
||||
# Step 3: SiLU + Mul PER SLOT — nonlinearity before combining paths
|
||||
gate, up = l1_slots.chunk(2, dim=-1)
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
if activation_clamp is not None:
|
||||
activated = activated.clamp(max=activation_clamp)
|
||||
|
||||
# NaN-trace: check SiLU output
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[silu] nan={torch.isnan(activated).any().item()} "
|
||||
f"abs_max={activated.abs().max().item():.4e}")
|
||||
|
||||
# Step 4: Quantize L1 output → FP4
|
||||
# Step 4: Quantize activated slots → FP4
|
||||
l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated)
|
||||
|
||||
# ALWAYS-ON debug: L2 alpha and scale ranges
|
||||
_l1sf_f32 = l1_sf_out.to(torch.float32)
|
||||
_l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item() if hasattr(l2_global_scale, 'item') else float(l2_global_scale)
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}] activated amax={activated.abs().max().item():.4e}")
|
||||
_l1sf_f32 = l1_sf_out.to(torch.float32)
|
||||
_l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item()
|
||||
print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]")
|
||||
|
||||
# Step 5: L2 GEMM (native NVFP4 block-scaled MMA)
|
||||
l2_output = nvfp4_mega_moe_l2(
|
||||
# Step 5: L2 GEMM — slot-based, no routing weights
|
||||
l2_slots = nvfp4_mega_moe_l2(
|
||||
l1_fp4, l1_sf_out, l2_w, l2_sf,
|
||||
topk_ids_local, topk_weights, num_experts_per_rank,
|
||||
topk_ids_local, slot_token,
|
||||
alpha=l2_global_scale,
|
||||
)
|
||||
) # (num_slots, HIDDEN) bfloat16
|
||||
|
||||
# NaN-trace: check L2 output
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} "
|
||||
f"abs_max={l2_output.abs().max().item():.4e}")
|
||||
print(f"[L2-out] nan={torch.isnan(l2_slots).any().item()} "
|
||||
f"abs_max={l2_slots.abs().max().item():.4e}")
|
||||
|
||||
# Step 6: Write to output (caller handles cross-rank all-reduce)
|
||||
y.copy_(l2_output)
|
||||
# Step 6: Final scatter — routing weights applied ONCE
|
||||
y.zero_()
|
||||
y.index_add_(0, slot_token, slot_weight[:, None] * l2_slots)
|
||||
|
||||
Reference in New Issue
Block a user