[Bugfix] Fix decode tokens w. CUDA graph (#6757)
This commit is contained in:
@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
for _ in range(expected_bs - len(seq_lens)):
|
||||
seq_lens.append(1)
|
||||
assert attn_metadata.seq_lens == seq_lens
|
||||
assert attn_metadata.num_decode_tokens == len(seq_lens)
|
||||
start_idx = 0
|
||||
start_loc = [start_idx]
|
||||
for _ in context_lens:
|
||||
|
||||
Reference in New Issue
Block a user