diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8b818f67c..c9d9ecf4a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1191,13 +1191,14 @@ class GPUModelRunner(
             return
 
         # Find the number of accepted tokens for each sequence.
-        num_accepted_tokens = (
+        num_reqs = output_token_ids.size(0)
+        self.num_accepted_tokens.gpu[:num_reqs] = (
             (
                 torch.cat(
                     [
                         output_token_ids,
                         torch.full(
-                            (output_token_ids.size(0), 1),
+                            (num_reqs, 1),
                             -1,
                             device=output_token_ids.device,
                         ),
@@ -1208,12 +1209,13 @@ class GPUModelRunner(
             )
             .int()
             .argmax(-1)
-            .cpu()
-            .numpy()
         )
-        for i, num_tokens in enumerate(num_accepted_tokens):
-            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
         if self.cache_config.mamba_cache_mode == "align":
+            for i, num_tokens in enumerate(
+                self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
+            ):
+                self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
+
             mamba_utils.postprocess_mamba(
                 scheduler_output,
                 self.kv_cache_config,
@@ -1224,6 +1226,10 @@ class GPUModelRunner(
                 self.model.get_mamba_state_copy_func(),
                 self._get_mamba_copy_bufs(),
             )
+        else:
+            self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
+                self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True
+            )
 
     def _update_streaming_request(
         self, req_id: str, new_req_data: NewRequestData
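
For readers unfamiliar with the sentinel trick used above: the per-request accepted-token count is obtained by appending a `-1` column to `output_token_ids` and taking the `argmax` of the `== -1` mask, so the index of the first rejection sentinel in each row equals the number of accepted tokens. A minimal standalone sketch of that idea (the function name `count_accepted` and the toy tensors are illustrative, not vLLM code):

```python
import torch


def count_accepted(output_token_ids: torch.Tensor) -> torch.Tensor:
    """output_token_ids: (num_reqs, max_tokens); rejected slots hold -1."""
    num_reqs = output_token_ids.size(0)
    # Append a -1 column so every row contains at least one sentinel; argmax of
    # the boolean mask then returns the index of the first -1, i.e. the number
    # of accepted tokens, entirely on the tensor's own device.
    padded = torch.cat(
        [
            output_token_ids,
            torch.full((num_reqs, 1), -1, device=output_token_ids.device),
        ],
        dim=1,
    )
    return (padded == -1).int().argmax(-1)


# Example: request 0 accepts 3 tokens, request 1 accepts 1 token.
ids = torch.tensor([[5, 7, 9, -1], [2, -1, -1, -1]])
print(count_accepted(ids))  # tensor([3, 1])
```

The rest of the change keeps this result on the GPU in `self.num_accepted_tokens.gpu` and, outside the Mamba `align` path, transfers it to the host via `copy_(..., non_blocking=True)` into `num_accepted_tokens_cpu_tensor`, replacing the synchronous `.cpu().numpy()` round trip and the per-request Python loop; the `align` path still iterates the values in Python before calling `mamba_utils.postprocess_mamba`.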