diff --git a/tests/layertest.py b/tests/layertest.py index 6513e49c..fc0f178b 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -396,7 +396,7 @@ def main(): print(" STEP 1: Loading original MXFP4 checkpoint") print("="*70) - orig_tensors = load_layer_tensors(ORIG_MODEL_DIR, LAYER_IDX, prefix_filter="experts") + orig_tensors = load_layer_tensors(ORIG_MODEL_DIR, LAYER_IDX) print_layer_keys(orig_tensors, "Original checkpoint (MXFP4)") # Dequantize to BF16 @@ -430,7 +430,7 @@ def main(): print(" STEP 3: Loading NVFP4 checkpoint") print("="*70) - nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX, prefix_filter="experts") + nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX) print_layer_keys(nvfp4_tensors, "NVFP4 checkpoint") # Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu)