diff --git a/tests/unit/test_part_a_decode_diagnostics.py b/tests/unit/test_part_a_decode_diagnostics.py index 1664879d..f2fd4fe3 100644 --- a/tests/unit/test_part_a_decode_diagnostics.py +++ b/tests/unit/test_part_a_decode_diagnostics.py @@ -23,6 +23,9 @@ CHECKPOINT_DIR = os.environ.get( NUM_GPUS = int(os.environ.get("NUM_GPUS", "8")) DEVICE = "cuda:0" TEST_LAYERS = int(os.environ.get("TEST_LAYERS", "5")) +# First layer index to test. L0-1 are hash routing, L2+ are dense/CSA/HCA. +# Set to 0 to include hash layers. +FIRST_LAYER = int(os.environ.get("FIRST_LAYER", "2")) def cosine(a, b): @@ -280,6 +283,9 @@ def main(): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") torch.cuda.set_device(gpu) + if pi == 0: + r = routers.get(li) + print(f" L{li} router: mode={r.mode if r else 'None'} has_gate_lin={r._gate_lin is not None if r and hasattr(r, '_gate_lin') else 'N/A'}", flush=True) X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li), kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),