diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 61de57a0..46bbe279 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -55,7 +55,7 @@ def load_expert_weights(layer_idx, num_experts): for shard_path in shards: with safe_open(shard_path, framework="pt", device="cpu") as f: for e in range(num_experts): - if len(experts) > e: + if e < len(experts): continue prefix = f"model.layers.{layer_idx}.mlp.experts.{e}" gate_w = f"{prefix}.gate_proj.weight"