diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 94a00c825..29fe9ec83 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -727,8 +727,10 @@ class GPUModelRunner( self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None self.valid_sampled_token_count_cpu: torch.Tensor | None = None self.draft_token_ids_cpu: torch.Tensor | None = None + self.num_accepted_tokens_event: torch.Event | None = None if self.num_spec_tokens: self.draft_token_ids_event = torch.Event() + self.num_accepted_tokens_event = torch.Event() self.draft_token_ids_copy_stream = torch.cuda.Stream() self.draft_token_ids_cpu = torch.empty( (self.max_num_reqs, self.num_spec_tokens), @@ -1229,6 +1231,8 @@ class GPUModelRunner( self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_( self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True ) + assert self.num_accepted_tokens_event is not None + self.num_accepted_tokens_event.record() def _update_streaming_request( self, req_id: str, new_req_data: NewRequestData @@ -1773,6 +1777,8 @@ class GPUModelRunner( max_seq_len = self.seq_lens.np[:num_reqs].max().item() if use_spec_decode: + if self.num_accepted_tokens_event is not None: + self.num_accepted_tokens_event.synchronize() self.num_accepted_tokens.np[:num_reqs] = ( self.input_batch.num_accepted_tokens_cpu[:num_reqs] )