Fix N for C128A (need 128 tokens)

This commit is contained in:
2026-05-19 16:04:53 +00:00
parent 49c2e088d4
commit abff942edd

View File

@@ -235,7 +235,7 @@ def test_csa_layer(layer_id, compress_ratio):
swa_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.uint8, device=DEV)
swa_inv_scale = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV)
N = 16 # Prefill tokens (use a multiple of compress_ratio)
N = 128 if cr >= 128 else 16 # Prefill tokens (use a multiple of compress_ratio)
assert N % cr == 0, f"N={N} must be multiple of compress_ratio={cr}"
token_ids = torch.arange(1, N + 1, dtype=torch.long, device=DEV)