This commit is contained in:
@@ -121,6 +121,12 @@ class BlockTable:
|
|||||||
self.num_blocks_per_row[row_idx] = 0
|
self.num_blocks_per_row[row_idx] = 0
|
||||||
self.append_row(block_ids, row_idx)
|
self.append_row(block_ids, row_idx)
|
||||||
|
|
||||||
|
def clear_row(self, row_idx: int) -> None:
|
||||||
|
num_blocks = self.num_blocks_per_row[row_idx]
|
||||||
|
if num_blocks > 0:
|
||||||
|
self.block_table.np[row_idx, :num_blocks] = 0
|
||||||
|
self.num_blocks_per_row[row_idx] = 0
|
||||||
|
|
||||||
def move_row(self, src: int, tgt: int) -> None:
|
def move_row(self, src: int, tgt: int) -> None:
|
||||||
num_blocks = self.num_blocks_per_row[src]
|
num_blocks = self.num_blocks_per_row[src]
|
||||||
block_table_np = self.block_table.np
|
block_table_np = self.block_table.np
|
||||||
@@ -275,6 +281,10 @@ class MultiGroupBlockTable:
|
|||||||
for i, block_table in enumerate(self.block_tables):
|
for i, block_table in enumerate(self.block_tables):
|
||||||
block_table.add_row(block_ids[i], row_idx)
|
block_table.add_row(block_ids[i], row_idx)
|
||||||
|
|
||||||
|
def clear_row(self, row_idx: int) -> None:
|
||||||
|
for block_table in self.block_tables:
|
||||||
|
block_table.clear_row(row_idx)
|
||||||
|
|
||||||
def move_row(self, src: int, tgt: int) -> None:
|
def move_row(self, src: int, tgt: int) -> None:
|
||||||
for block_table in self.block_tables:
|
for block_table in self.block_tables:
|
||||||
block_table.move_row(src, tgt)
|
block_table.move_row(src, tgt)
|
||||||
|
|||||||
@@ -496,6 +496,7 @@ class InputBatch:
|
|||||||
self._req_ids[req_index] = None
|
self._req_ids[req_index] = None
|
||||||
self.req_output_token_ids[req_index] = None
|
self.req_output_token_ids[req_index] = None
|
||||||
self.spec_token_ids[req_index].clear()
|
self.spec_token_ids[req_index].clear()
|
||||||
|
self.block_table.clear_row(req_index)
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
lora_id = self.request_lora_mapping[req_index]
|
lora_id = self.request_lora_mapping[req_index]
|
||||||
|
|||||||
@@ -5376,6 +5376,12 @@ class GPUModelRunner(
|
|||||||
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
|
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
|
||||||
self.query_start_loc.copy_to_gpu()
|
self.query_start_loc.copy_to_gpu()
|
||||||
|
|
||||||
|
# Sync block table CPU->GPU so cleared rows from
|
||||||
|
# remove_request() are visible to the attention metadata
|
||||||
|
# builder. Without this, stale block IDs from finished
|
||||||
|
# requests can corrupt Mamba state.
|
||||||
|
self.input_batch.block_table.commit_block_table(num_reqs_padded)
|
||||||
|
|
||||||
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
|
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
|
||||||
attn_metadata, _ = self._build_attention_metadata(
|
attn_metadata, _ = self._build_attention_metadata(
|
||||||
num_tokens=num_tokens_unpadded,
|
num_tokens=num_tokens_unpadded,
|
||||||
|
|||||||
Reference in New Issue
Block a user