debug: more prints
This commit is contained in:
@@ -204,6 +204,8 @@ def run_nvfp4_moe(
|
||||
intermediate_size = l1_out.shape[1] // 2
|
||||
gate = l1_out[:, :intermediate_size]
|
||||
up = l1_out[:, intermediate_size:]
|
||||
print(f" gate: shape={gate.shape}, amax={gate.abs().amax().item():.4f}", flush=True)
|
||||
print(f" up: shape={up.shape}, amax={up.abs().amax().item():.4f}", flush=True)
|
||||
activated = torch.nn.functional.silu(gate) * up # (num_slots, intermediate) BF16
|
||||
print(f" After SiLU(gate)*up: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user