debug: add L1 GEMM and SiLU output debug prints
This commit is contained in:
@@ -194,7 +194,8 @@ def run_nvfp4_moe(
|
||||
scale_a=l1_scale_a, scale_b=l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l1_global_scale_a, global_scale_b=l1_global_scale_b,
|
||||
) # (num_slots, intermediate) BF16
|
||||
) # (num_slots, 2*intermediate) BF16
|
||||
print(f" L1 GEMM output: shape={l1_out.shape}, amax={l1_out.abs().amax().item():.4f}", flush=True)
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# SiLU(gate) * up (BF16 — nonlinear requires BF16)
|
||||
@@ -204,6 +205,7 @@ def run_nvfp4_moe(
|
||||
gate = l1_out[:, :intermediate_size]
|
||||
up = l1_out[:, intermediate_size:]
|
||||
activated = torch.nn.functional.silu(gate) * up # (num_slots, intermediate) BF16
|
||||
print(f" After SiLU(gate)*up: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True)
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# L2: down projection (NVFP4 × NVFP4 → BF16)
|
||||
|
||||
Reference in New Issue
Block a user