diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 579c9b7a5..34bcc241f 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -529,6 +529,12 @@ class InputBatch:
     def swap_states(self, i1: int, i2: int) -> None:
         old_id_i1 = self._req_ids[i1]
         old_id_i2 = self._req_ids[i2]
+        # Only swap the active token prefix for each request. Copying full
+        # max_model_len rows is expensive and unnecessary during reordering.
+        i1_active_token_count = self._get_active_token_count(i1)
+        i2_active_token_count = self._get_active_token_count(i2)
+        max_active_token_count = max(i1_active_token_count, i2_active_token_count)
+
         self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
         self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
             self.req_output_token_ids[i2],
@@ -560,12 +566,15 @@
         # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
         # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
         # instead, we need to temporarily copy the data for one of the indices
-        # TODO(lucas): optimize this by only copying valid indices
-        tmp = self.token_ids_cpu[i1, ...].copy()
-        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
-        self.token_ids_cpu[i2, ...] = tmp
+        tmp_token_ids = self.token_ids_cpu[i1, :max_active_token_count].copy()
+        self.token_ids_cpu[i1, :max_active_token_count] = self.token_ids_cpu[
+            i2, :max_active_token_count
+        ]
+        self.token_ids_cpu[i2, :max_active_token_count] = tmp_token_ids
 
-        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
+        self.is_token_ids[[i1, i2], :max_active_token_count] = self.is_token_ids[
+            [i2, i1], :max_active_token_count
+        ]
 
         # Swap prompt embeddings if they exist
         embeds_i1 = self.req_prompt_embeds.get(i1)
@@ -629,6 +638,11 @@ class InputBatch:
                 self.allowed_token_ids_mask_cpu_tensor[i1],
             )
 
+    def _get_active_token_count(self, req_index: int) -> int:
+        return int(self.num_tokens_no_spec[req_index]) + len(
+            self.spec_token_ids[req_index]
+        )
+
     def condense(self) -> None:
         """Slide non-empty requests down into lower, empty indices.
 
@@ -678,9 +692,7 @@
             self.req_output_token_ids[last_req_index] = None
             self.req_id_to_index[req_id] = empty_index
 
-            num_tokens = self.num_tokens_no_spec[last_req_index] + len(
-                self.spec_token_ids[last_req_index]
-            )
+            num_tokens = self._get_active_token_count(last_req_index)
             (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                 self.spec_token_ids[empty_index],