diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index e91b9503..71f27ede 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -1,26 +1,12 @@ -"""Pipeline Test: Step-by-step using CuTeDSL bridge + BF16 reference. - -Tests each stage of the NVFP4 MoE pipeline: - 1. Token sort + expert assignment - 2. L1 GEMM (gate+up) - 3. SwiGLU activation - 4. L2 GEMM (down_proj) - 5. Scatter-add - 6. Full runner end-to-end - -Incrementally enable stages with STAGE_START/STAGE_END. +"""Debug test: Replicate runner logic step by step in Python. +Compare against BF16 reference to isolate where tokens get dropped. """ import torch import torch.nn.functional as F -import sys -import os -import glob +import sys, os, glob sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) -# ============================================================ -# CONFIG -# ============================================================ MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" LAYER_IDX = 0 NUM_EXPERTS = 48 @@ -31,9 +17,6 @@ TOP_K = 6 SWIGLU_LIMIT = 10.0 DEVICE = "cuda" -STAGE_START = 1 -STAGE_END = 6 - def load_layer_tensors(model_dir, layer_idx): tensors = {} @@ -42,204 +25,205 @@ def load_layer_tensors(model_dir, layer_idx): data = load_file(sf) for k, v in data.items(): if f"layers.{layer_idx}." in k and "mlp.experts" in k: - norm_key = k.removeprefix("model.") - tensors[norm_key] = v + tensors[k.removeprefix("model.")] = v return tensors def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): - """Dequantize NVFP4 to BF16. Input: (N, K_packed), scale: (N, K_sf).""" - lut = torch.tensor([ - 0., 0.5, 1., 1.5, 2., 3., 4., 6., - -0., -0.5, -1., -1.5, -2., -3., -4., -6. - ], dtype=torch.float32, device=packed_uint8.device) + lut = torch.tensor([0.,0.5,1.,1.5,2.,3.,4.,6.,-0.,-0.5,-1.,-1.5,-2.,-3.,-4.,-6.], + dtype=torch.float32, device=packed_uint8.device) lower = lut[(packed_uint8 & 0x0F).long()] upper = lut[((packed_uint8 >> 4) & 0x0F).long()] - N = packed_uint8.shape[0] - K = packed_uint8.shape[1] * 2 - bf16_vals = torch.stack([lower, upper], dim=-1).reshape(N, K) + N, K = packed_uint8.shape[0], packed_uint8.shape[1] * 2 + bf16 = torch.stack([lower, upper], dim=-1).reshape(N, K) K_sf = scale_e4m3.shape[1] scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=1) - return (bf16_vals * scale_2d * global_scale).to(torch.bfloat16) - - -def bf16_moe_reference(hidden_states, nvfp4_tensors, layer_idx, expert_indices, topk_ids, topk_weights, swiglu_limit): - """Full BF16 reference MoE. Returns output + per-expert intermediates.""" - num_tokens = hidden_states.shape[0] - top_k = topk_ids.shape[1] - output = torch.zeros(num_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) - - # Per-expert intermediates (keyed by local expert index) - expert_data = {} - - for i, e in enumerate(expert_indices): - gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) - up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) - gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) - up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) - gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() - up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() - - gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) - up_bf16 = dequantize_nvfp4_weight(up_w, up_sf, up_gs) - - down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" - if down_key in nvfp4_tensors: - down_w = nvfp4_tensors[down_key].to(DEVICE) - down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) - down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() - down_bf16 = dequantize_nvfp4_weight(down_w, down_sf, down_gs) - else: - down_bf16 = None - - # Collect tokens for this expert - mask = (topk_ids == i) # (num_tokens, top_k) - token_rows, k_rows = torch.where(mask) - if token_rows.numel() == 0: - continue - - x = hidden_states[token_rows] # (T, H) - gate = x @ gate_bf16.T - up = x @ up_bf16.T - l1_out = torch.cat([gate, up], dim=1) - - gate_silu = F.silu(gate) - if swiglu_limit is not None: - gate_silu = gate_silu.clamp(max=swiglu_limit) - up = up.clamp(min=-swiglu_limit, max=swiglu_limit) - activated = gate_silu * up - - if down_bf16 is not None: - l2_out = activated @ down_bf16.T - else: - l2_out = activated[:, :HIDDEN_SIZE] - - # Scatter - weights = topk_weights[token_rows, k_rows] - weighted = l2_out * weights.unsqueeze(1).to(l2_out.dtype) - output.scatter_add_(0, token_rows.unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted) - - expert_data[i] = { - 'tokens': token_rows, - 'x': x, - 'gate': gate, 'up': up, 'l1_out': l1_out, - 'activated': activated, - 'l2_out': l2_out, - } - - return output, expert_data - - -def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate_size): - l1_fp4, l1_sf, l1_gs = [], [], [] - l2_fp4, l2_sf, l2_gs = [], [], [] - for e in expert_indices: - gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) - up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) - gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) - up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) - gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() - up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() - fused_w = torch.cat([gate_w, up_w], dim=0).view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - fused_sf = torch.cat([gate_sf, up_sf], dim=0).permute(1, 0).contiguous() - l1_max_gs = max(gate_gs, up_gs) - if gate_gs != up_gs: - sf32 = fused_sf.float() - sf32[:, :intermediate_size] *= (gate_gs / l1_max_gs) - sf32[:, intermediate_size:] *= (up_gs / l1_max_gs) - fused_sf = sf32.to(torch.float8_e4m3fn) - l1_fp4.append(fused_w); l1_sf.append(fused_sf); l1_gs.append(l1_max_gs) - - down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" - if down_key in nvfp4_tensors: - dw = nvfp4_tensors[down_key].to(DEVICE) - dsf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) - dgs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() - l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()) - l2_sf.append(dsf.permute(1, 0).contiguous()); l2_gs.append(dgs) - else: - l2_fp4.append(torch.zeros(intermediate_size // 2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) - l2_sf.append(torch.ones(intermediate_size // 16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE)) - l2_gs.append(1.0) - return {'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs, - 'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs} + return (bf16 * scale_2d * global_scale).to(torch.bfloat16) def main(): torch.cuda.set_device(0) torch.manual_seed(42) - - print(f"=== Pipeline Test (stages {STAGE_START}-{STAGE_END}) ===") - print(f" {NUM_EXPERTS} experts, H={HIDDEN_SIZE}, I={INTERMEDIATE_SIZE}, T={NUM_TOKENS}, top_k={TOP_K}") - - # Load weights + + print("=== Runner Logic Debug ===") nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX) - print(f" {len(nvfp4_tensors)} tensors loaded") - - expert_indices = list(range(NUM_EXPERTS)) - weights = prepare_nvfp4_weights(nvfp4_tensors, LAYER_IDX, expert_indices, INTERMEDIATE_SIZE) - + hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE) for i in range(NUM_TOKENS): topk_ids[i] = torch.randperm(NUM_EXPERTS)[:TOP_K] topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K - - # BF16 reference - print("\n--- BF16 Reference ---") - ref_out, ref_expert = bf16_moe_reference(hidden_states, nvfp4_tensors, LAYER_IDX, expert_indices, topk_ids, topk_weights, SWIGLU_LIMIT) - print(f" Output: amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}") - for i in list(ref_expert.keys())[:3]: - d = ref_expert[i] - print(f" Expert {i}: {d['tokens'].numel()} tokens, l1_amax={d['l1_out'].amax().item():.4f} act_amax={d['activated'].amax().item():.4f} l2_amax={d['l2_out'].amax().item():.4f}") - - # Full CuTeDSL runner - if STAGE_END >= 6: - print("\n--- Stage 6: Full CuTeDSL Runner ---") - from vllm.nvfp4_cutedsl import CuTeDSLMoERunner - runner = CuTeDSLMoERunner( - num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, - top_k=TOP_K, device=DEVICE, - ) - runner.l1_fp4 = weights['l1_fp4']; runner.l1_sf = weights['l1_sf']; runner.l1_gs = weights['l1_gs'] - runner.l2_fp4 = weights['l2_fp4']; runner.l2_sf = weights['l2_sf']; runner.l2_gs = weights['l2_gs'] - runner.set_swiglu_limit(SWIGLU_LIMIT) - with torch.no_grad(): - runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids) - print(f" Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}") - runner_out = runner.run(hidden_states, topk_weights, topk_ids) - print(f" Runner: amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}") - print(f" NaN: {torch.isnan(runner_out).any().item()}") - - cos = F.cosine_similarity(ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)).item() - print(f" vs BF16: cosine={cos:.6f}") - for t in range(NUM_TOKENS): - ct = F.cosine_similarity(ref_out[t].unsqueeze(0), runner_out[t].unsqueeze(0)).item() - if ct < 0.9: - print(f" Token {t}: cosine={ct:.4f} ref_max={ref_out[t].amax().item():.4f} run_max={runner_out[t].amax().item():.4f}") - - # layertest-style bridge reference (should match runner if runner is correct) - if STAGE_END >= 1 and STAGE_START <= 1: - print("\n--- Bridge Reference (run_nvfp4_moe) ---") - from cutedsl.bridge import run_nvfp4_moe - - # layertest uses 3 experts — let's use same subset for quick test - small_experts = list(range(min(3, NUM_EXPERTS))) - small_weights = prepare_nvfp4_weights(nvfp4_tensors, LAYER_IDX, small_experts, INTERMEDIATE_SIZE) - small_topk = torch.zeros(NUM_TOKENS, 2, dtype=torch.int32, device=DEVICE) - for i in range(NUM_TOKENS): - small_topk[i] = torch.tensor([0, 1], dtype=torch.int32) - small_tw = torch.tensor([[0.6, 0.4]] * NUM_TOKENS, dtype=torch.float32, device=DEVICE) - - bridge_out = run_nvfp4_moe(hidden_states, small_topk, small_tw, small_weights, small_experts) - - # BF16 ref for same subset - ref3, _ = bf16_moe_reference(hidden_states, nvfp4_tensors, LAYER_IDX, small_experts, small_topk, small_tw, SWIGLU_LIMIT) - cos3 = F.cosine_similarity(ref3.flatten().unsqueeze(0), bridge_out.flatten().unsqueeze(0)).item() - print(f" Bridge (3 experts) vs BF16: cosine={cos3:.6f}") - if cos3 >= 0.98: - print(" ✅ Bridge reference works correctly") + + # Step 1: Global→local remap (same as runner) + experts_start_idx = 0 + local_ids = topk_ids - experts_start_idx + local_mask = (local_ids >= 0) & (local_ids < NUM_EXPERTS) + safe_ids = local_ids.clamp(0, NUM_EXPERTS - 1) + safe_weights = topk_weights * local_mask.float() + + print(f"topk_ids:\n{topk_ids}") + print(f"local_ids:\n{local_ids}") + print(f"local_mask:\n{local_mask}") + print(f"safe_weights (should all be 0.1667):\n{safe_weights}") + + # Step 2: Sort by expert + flat_ids = safe_ids.reshape(-1) + flat_weights = safe_weights.reshape(-1) + num_slots = NUM_TOKENS * TOP_K + token_indices = torch.arange(num_slots, device=DEVICE) + + sort_idx = flat_ids.argsort(stable=True) + sorted_ids = flat_ids[sort_idx] + sorted_weights = flat_weights[sort_idx] + sorted_token_ids = token_indices[sort_idx] + + print(f"\nsorted_ids: {sorted_ids.tolist()}") + print(f"sorted_token_ids: {sorted_token_ids.tolist()}") + print(f"sorted_weights: {sorted_weights.tolist()}") + + # Step 3: Expert offsets + expert_id_range = torch.arange(NUM_EXPERTS, device=DEVICE) + tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() + expert_offsets = torch.zeros(NUM_EXPERTS + 1, dtype=torch.int32, device=DEVICE) + expert_offsets[1:] = tokens_per_expert.cumsum(0) + + print(f"\ntokens_per_expert: {tokens_per_expert.tolist()}") + print(f"expert_offsets: {expert_offsets.tolist()}") + + # Step 4: Padded offsets + padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 + padded_expert_offsets = torch.zeros(NUM_EXPERTS + 1, dtype=torch.int32, device=DEVICE) + padded_expert_offsets[1:] = padded_tokens_per_expert.cumsum(0) + total_padded = padded_expert_offsets[NUM_EXPERTS].item() + + print(f"padded_tokens_per_expert: {padded_tokens_per_expert.tolist()}") + print(f"padded_expert_offsets: {padded_expert_offsets.tolist()}") + print(f"total_padded: {total_padded}") + + # Step 5: Scatter into padded layout (runner's searchsorted approach) + row_indices = torch.arange(num_slots, device=DEVICE) + expert_assign = torch.searchsorted(expert_offsets[1:], row_indices, right=True).clamp(max=NUM_EXPERTS - 1) + local_row = row_indices - expert_offsets[expert_assign] + padded_dst = padded_expert_offsets[expert_assign] + local_row + + print(f"\nexpert_assign: {expert_assign.tolist()}") + print(f"local_row: {local_row.tolist()}") + print(f"padded_dst: {padded_dst.tolist()}") + + # Verify: expert_assign should match sorted_ids + match = (expert_assign == sorted_ids).all().item() + print(f"expert_assign == sorted_ids: {match}") + if not match: + mismatches = (expert_assign != sorted_ids).nonzero().squeeze() + print(f" Mismatch at rows: {mismatches.tolist()}") + print(f" expert_assign[mismatch]: {expert_assign[mismatches].tolist()}") + print(f" sorted_ids[mismatch]: {sorted_ids[mismatches].tolist()}") + + # Step 6: Scatter hidden states + slot_hidden = hidden_states[sorted_token_ids] + padded_hidden = torch.zeros(total_padded, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) + padded_hidden[padded_dst] = slot_hidden + + # Verify: padded_hidden[padded_dst] should match slot_hidden + verify = (padded_hidden[padded_dst] == slot_hidden).all().item() + print(f"\npadded_hidden scatter correct: {verify}") + + # Step 7: Now run L1 GEMM using bridge (direct call, not runner) + from cutedsl.bridge import ( + quantize_to_nvfp4, run_nvfp4_grouped_gemm, + assemble_scales_3d_side, make_b_k_major, + ) + + # Prepare weights (same as runner's _ensure_stacked) + expert_indices = list(range(NUM_EXPERTS)) + l1_fp4, l1_sf, l1_gs_list = [], [], [] + for e in expert_indices: + gw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) + uw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) + gsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + usf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + ggs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + ugs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale_2"].item() + fw = torch.cat([gw, uw], dim=0).view(torch.float4_e2m1fn_x2).permute(1,0).contiguous() + fsf = torch.cat([gsf, usf], dim=0).permute(1,0).contiguous() + mgs = max(ggs, ugs) + if ggs != ugs: + fsf32 = fsf.float() + fsf32[:, :INTERMEDIATE_SIZE] *= (ggs / mgs) + fsf32[:, INTERMEDIATE_SIZE:] *= (ugs / mgs) + fsf = fsf32.to(torch.float8_e4m3fn) + l1_fp4.append(fw); l1_sf.append(fsf); l1_gs_list.append(mgs) + + l1_mat_b = torch.stack(l1_fp4) + l1_mat_b = make_b_k_major(l1_mat_b) + l1_scale_b = assemble_scales_3d_side(l1_sf) + l1_gsb = torch.tensor(l1_gs_list, dtype=torch.float32, device=DEVICE) + + # Quantize activation (dynamic gs, not warmup) + print("\n--- L1 GEMM (dynamic gs) ---") + x_fp4, x_sf, l1_gs = quantize_to_nvfp4(padded_hidden) + print(f" L1 gs (dynamic): {l1_gs:.6f}") + + # For scale_a, we need to use the runner's assembly approach. + # Use the same _assemble_scales_cudagraph_safe function + from vllm.nvfp4_cutedsl import CuTeDSLMoERunner + runner = CuTeDSLMoERunner( + num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, + top_k=TOP_K, device=DEVICE, + ) + # Just use the runner's scale assembly + l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE) + l1_scale_a = runner._assemble_scales_cudagraph_safe( + x_sf[:num_slots], expert_offsets[:NUM_EXPERTS+1], + padded_expert_offsets, + runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1 + ) + + l1_out = run_nvfp4_grouped_gemm( + mat_a=x_fp4, mat_b=l1_mat_b, + scale_a=l1_scale_a, scale_b=l1_scale_b, + expert_offsets=padded_expert_offsets[1:NUM_EXPERTS+1], + global_scale_a=l1_gsa, global_scale_b=l1_gsb, + ) + print(f" L1 out: shape={l1_out.shape} amax={l1_out.amax().item():.4f}") + print(f" L1 out NaN: {torch.isnan(l1_out).any().item()}") + + # Extract real tokens + l1_out_real = l1_out[padded_dst] + print(f" L1 real: amax={l1_out_real.amax().item():.4f}") + + # BF16 reference L1 + ref_l1 = torch.zeros(num_slots, 2*INTERMEDIATE_SIZE, dtype=torch.bfloat16, device=DEVICE) + for i, e in enumerate(expert_indices): + start = expert_offsets[i].item() + end = expert_offsets[i+1].item() + if start == end: + continue + x = slot_hidden[start:end] + gw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) + uw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) + gsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + usf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + ggs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + ugs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale_2"].item() + gate = x @ dequantize_nvfp4_weight(gw, gsf, ggs).T + up = x @ dequantize_nvfp4_weight(uw, usf, ugs).T + ref_l1[start:end] = torch.cat([gate, up], dim=1) + + # Compare L1 + cos_l1 = F.cosine_similarity(ref_l1.flatten().unsqueeze(0), l1_out_real.flatten().unsqueeze(0)).item() + print(f"\n L1 cosine vs BF16: {cos_l1:.6f}") + + # Per-expert L1 comparison + for i in list(range(NUM_EXPERTS))[:5]: + start = expert_offsets[i].item() + end = expert_offsets[i+1].item() + if start == end: + continue + c = F.cosine_similarity(ref_l1[start:end].flatten().unsqueeze(0), + l1_out_real[start:end].flatten().unsqueeze(0)).item() + print(f" Expert {i} L1: cosine={c:.6f} ref_amax={ref_l1[start:end].amax().item():.4f} run_amax={l1_out_real[start:end].amax().item():.4f}") if __name__ == "__main__":