diff --git a/single_shot_inference.py b/single_shot_inference.py index 1ad13283..c4aafb06 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -861,7 +861,7 @@ def main(): n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128); itk = cfg.get("index_topk", 1024) for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128 - kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), dev) + kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev) if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)