diff --git a/tests/unit/test_decode_fmha_layer.py b/tests/unit/test_decode_fmha_layer.py index 7ba9abc0..96e6dae7 100644 --- a/tests/unit/test_decode_fmha_layer.py +++ b/tests/unit/test_decode_fmha_layer.py @@ -387,7 +387,7 @@ def main(): # 4. Indexer top-k (CSA layers) topk_idx = None if indexer is not None and ratio == 4: - topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos.to(dev), layer_idx=li + topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos.to(dev), layer_idx=li) if topk_idx is not None: print(f" L{li} CSA: indexer topk shape={tuple(topk_idx.shape)} " f"range=[{topk_idx.min().item()}, {topk_idx.max().item()}] "