fix(worker): optimize swap_states to copy only active token prefixes (#34733)

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
This commit is contained in:
Philip Ottesen
2026-03-18 22:59:27 +01:00
committed by GitHub
parent 0d81a1fe61
commit 0091017188

View File

@@ -529,6 +529,12 @@ class InputBatch:
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
# Only swap the active token prefix for each request. Copying full
# max_model_len rows is expensive and unnecessary during reordering.
i1_active_token_count = self._get_active_token_count(i1)
i2_active_token_count = self._get_active_token_count(i2)
max_active_token_count = max(i1_active_token_count, i2_active_token_count)
self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
self.req_output_token_ids[i2],
@@ -560,12 +566,15 @@ class InputBatch:
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
tmp_token_ids = self.token_ids_cpu[i1, :max_active_token_count].copy()
self.token_ids_cpu[i1, :max_active_token_count] = self.token_ids_cpu[
i2, :max_active_token_count
]
self.token_ids_cpu[i2, :max_active_token_count] = tmp_token_ids
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
self.is_token_ids[[i1, i2], :max_active_token_count] = self.is_token_ids[
[i2, i1], :max_active_token_count
]
# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
@@ -629,6 +638,11 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)
def _get_active_token_count(self, req_index: int) -> int:
return int(self.num_tokens_no_spec[req_index]) + len(
self.spec_token_ids[req_index]
)
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
@@ -678,9 +692,7 @@ class InputBatch:
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
num_tokens = self.num_tokens_no_spec[last_req_index] + len(
self.spec_token_ids[last_req_index]
)
num_tokens = self._get_active_token_count(last_req_index)
(self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
self.spec_token_ids[empty_index],