Fix Mamba state corruption from referencing stale block table entries (#37728) (#37728) (#37728)

This commit is contained in:
Ming Yang
2026-03-24 10:29:59 -07:00
committed by GitHub
parent 4df5fa7439
commit c07e2ca6e0
3 changed files with 17 additions and 0 deletions

View File

@@ -121,6 +121,12 @@ class BlockTable:
self.num_blocks_per_row[row_idx] = 0 self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx) self.append_row(block_ids, row_idx)
def clear_row(self, row_idx: int) -> None:
num_blocks = self.num_blocks_per_row[row_idx]
if num_blocks > 0:
self.block_table.np[row_idx, :num_blocks] = 0
self.num_blocks_per_row[row_idx] = 0
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src] num_blocks = self.num_blocks_per_row[src]
block_table_np = self.block_table.np block_table_np = self.block_table.np
@@ -275,6 +281,10 @@ class MultiGroupBlockTable:
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx) block_table.add_row(block_ids[i], row_idx)
def clear_row(self, row_idx: int) -> None:
for block_table in self.block_tables:
block_table.clear_row(row_idx)
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.move_row(src, tgt) block_table.move_row(src, tgt)

View File

@@ -496,6 +496,7 @@ class InputBatch:
self._req_ids[req_index] = None self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index].clear() self.spec_token_ids[req_index].clear()
self.block_table.clear_row(req_index)
# LoRA # LoRA
lora_id = self.request_lora_mapping[req_index] lora_id = self.request_lora_mapping[req_index]

View File

@@ -5376,6 +5376,12 @@ class GPUModelRunner(
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
# Sync block table CPU->GPU so cleared rows from
# remove_request() are visible to the attention metadata
# builder. Without this, stale block IDs from finished
# requests can corrupt Mamba state.
self.input_batch.block_table.commit_block_table(num_reqs_padded)
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
attn_metadata, _ = self._build_attention_metadata( attn_metadata, _ = self._build_attention_metadata(
num_tokens=num_tokens_unpadded, num_tokens=num_tokens_unpadded,