diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 0f5446b44..f46e8a8ed 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -121,6 +121,12 @@ class BlockTable:
         self.num_blocks_per_row[row_idx] = 0
         self.append_row(block_ids, row_idx)
 
+    def clear_row(self, row_idx: int) -> None:
+        num_blocks = self.num_blocks_per_row[row_idx]
+        if num_blocks > 0:
+            self.block_table.np[row_idx, :num_blocks] = 0
+        self.num_blocks_per_row[row_idx] = 0
+
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
         block_table_np = self.block_table.np
@@ -275,6 +281,10 @@ class MultiGroupBlockTable:
         for i, block_table in enumerate(self.block_tables):
             block_table.add_row(block_ids[i], row_idx)
 
+    def clear_row(self, row_idx: int) -> None:
+        for block_table in self.block_tables:
+            block_table.clear_row(row_idx)
+
     def move_row(self, src: int, tgt: int) -> None:
         for block_table in self.block_tables:
             block_table.move_row(src, tgt)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index e20d268fe..11d57f1d7 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -496,6 +496,7 @@ class InputBatch:
         self._req_ids[req_index] = None
         self.req_output_token_ids[req_index] = None
         self.spec_token_ids[req_index].clear()
+        self.block_table.clear_row(req_index)
 
         # LoRA
         lora_id = self.request_lora_mapping[req_index]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6465ca654..be7734487 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5376,6 +5376,12 @@ class GPUModelRunner(
         self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
         self.query_start_loc.copy_to_gpu()
 
+        # Sync block table CPU->GPU so cleared rows from
+        # remove_request() are visible to the attention metadata
+        # builder. Without this, stale block IDs from finished
+        # requests can corrupt Mamba state.
+        self.input_batch.block_table.commit_block_table(num_reqs_padded)
+
         pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
         attn_metadata, _ = self._build_attention_metadata(
             num_tokens=num_tokens_unpadded,