[Bugfix] Fix decode tokens w. CUDA graph (#6757)

This commit is contained in:
Cody Yu
2024-07-24 22:33:56 -07:00
committed by GitHub
parent 9e169a4c61
commit 309aaef825
4 changed files with 31 additions and 4 deletions

View File

@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
for _ in range(expected_bs - len(seq_lens)):
seq_lens.append(1)
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.num_decode_tokens == len(seq_lens)
start_idx = 0
start_loc = [start_idx]
for _ in context_lens: