diff --git a/tests/test_csa_sparse_attn_b200.py b/tests/test_csa_sparse_attn_b200.py index 5f6566c4..5c75e229 100644 --- a/tests/test_csa_sparse_attn_b200.py +++ b/tests/test_csa_sparse_attn_b200.py @@ -235,7 +235,7 @@ def test_csa_layer(layer_id, compress_ratio): swa_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.uint8, device=DEV) swa_inv_scale = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV) - N = 16 # Prefill tokens (use a multiple of compress_ratio) + N = 128 if cr >= 128 else 16 # Prefill tokens (use a multiple of compress_ratio) assert N % cr == 0, f"N={N} must be multiple of compress_ratio={cr}" token_ids = torch.arange(1, N + 1, dtype=torch.long, device=DEV)