diff --git a/single_shot_inference.py b/single_shot_inference.py index 406512dc..e3c91ce6 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -398,9 +398,11 @@ class Indexer: wp_out = wp_w.shape[0] wp_in = wp_w.shape[1] * 2 self.wp_lin = make_nvfp4_linear(wp_in, wp_out, dev, w, pfx, 'weights_proj') - if f"{pfx}.compressor.kv_proj.weight" in w: + # Indexer compressor weights are directly under the indexer prefix + # (e.g. *.indexer.kv_proj.weight), NOT nested under *.indexer.compressor. + if f"{pfx}.kv_proj.weight" in w: self.compressor = Compressor(4, self.ihd, 7168, dev) - self.compressor.load(w, f"{pfx}.compressor", dev) + self.compressor.load(w, pfx, dev) def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None): if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0: