[BugFix] Fix Llama4 - Index Error When Single Request Near Max Context (#16209)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches(
         np.arange(pages_per_local_batch, dtype=np.int32),
         (virtual_batches, pages_per_local_batch)) \
             + np.expand_dims(block_starts, axis=1)
-    block_indices = block_indices.flatten()
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
     batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                               local_blocks * pages_per_local_batch)
     block_table_local = block_table[batch_indices, block_indices]\
|
|||||||
Reference in New Issue
Block a user