Fix prefill token count N for C128A (needs 128 tokens)
This commit is contained in:
@@ -235,7 +235,7 @@ def test_csa_layer(layer_id, compress_ratio):
|
||||
swa_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.uint8, device=DEV)
|
||||
swa_inv_scale = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV)
|
||||
|
||||
N = 16 # Prefill tokens (use a multiple of compress_ratio)
|
||||
N = 128 if cr >= 128 else 16 # Prefill tokens (use a multiple of compress_ratio)
|
||||
assert N % cr == 0, f"N={N} must be multiple of compress_ratio={cr}"
|
||||
token_ids = torch.arange(1, N + 1, dtype=torch.long, device=DEV)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user