From abff942edd6ea601a3cb958c5bb31cf217f6c340 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 16:04:53 +0000 Subject: [PATCH] Fix N for C128A (need 128 tokens) --- tests/test_csa_sparse_attn_b200.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_csa_sparse_attn_b200.py b/tests/test_csa_sparse_attn_b200.py index 5f6566c4..5c75e229 100644 --- a/tests/test_csa_sparse_attn_b200.py +++ b/tests/test_csa_sparse_attn_b200.py @@ -235,7 +235,7 @@ def test_csa_layer(layer_id, compress_ratio): swa_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.uint8, device=DEV) swa_inv_scale = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV) - N = 16 # Prefill tokens (use a multiple of compress_ratio) + N = 128 if cr >= 128 else 16 # Prefill tokens (use a multiple of compress_ratio) assert N % cr == 0, f"N={N} must be multiple of compress_ratio={cr}" token_ids = torch.arange(1, N + 1, dtype=torch.long, device=DEV)