This commit is contained in:
@@ -121,6 +121,12 @@ class BlockTable:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
self.append_row(block_ids, row_idx)
|
||||
|
||||
def clear_row(self, row_idx: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[row_idx]
|
||||
if num_blocks > 0:
|
||||
self.block_table.np[row_idx, :num_blocks] = 0
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
block_table_np = self.block_table.np
|
||||
@@ -275,6 +281,10 @@ class MultiGroupBlockTable:
|
||||
for i, block_table in enumerate(self.block_tables):
|
||||
block_table.add_row(block_ids[i], row_idx)
|
||||
|
||||
def clear_row(self, row_idx: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.clear_row(row_idx)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.move_row(src, tgt)
|
||||
|
||||
@@ -496,6 +496,7 @@ class InputBatch:
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
self.spec_token_ids[req_index].clear()
|
||||
self.block_table.clear_row(req_index)
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
|
||||
@@ -5376,6 +5376,12 @@ class GPUModelRunner(
|
||||
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
|
||||
# Sync block table CPU->GPU so cleared rows from
|
||||
# remove_request() are visible to the attention metadata
|
||||
# builder. Without this, stale block IDs from finished
|
||||
# requests can corrupt Mamba state.
|
||||
self.input_batch.block_table.commit_block_table(num_reqs_padded)
|
||||
|
||||
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
|
||||
attn_metadata, _ = self._build_attention_metadata(
|
||||
num_tokens=num_tokens_unpadded,
|
||||
|
||||
Reference in New Issue
Block a user