fix(worker): optimize swap_states to copy only active token prefixes (#34733)
Signed-off-by: Philip Ottesen <phiott256@gmail.com>
This commit is contained in:
@@ -529,6 +529,12 @@ class InputBatch:
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
# Only swap the active token prefix for each request. Copying full
|
||||
# max_model_len rows is expensive and unnecessary during reordering.
|
||||
i1_active_token_count = self._get_active_token_count(i1)
|
||||
i2_active_token_count = self._get_active_token_count(i2)
|
||||
max_active_token_count = max(i1_active_token_count, i2_active_token_count)
|
||||
|
||||
self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
|
||||
self.req_output_token_ids[i2],
|
||||
@@ -560,12 +566,15 @@ class InputBatch:
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
tmp_token_ids = self.token_ids_cpu[i1, :max_active_token_count].copy()
|
||||
self.token_ids_cpu[i1, :max_active_token_count] = self.token_ids_cpu[
|
||||
i2, :max_active_token_count
|
||||
]
|
||||
self.token_ids_cpu[i2, :max_active_token_count] = tmp_token_ids
|
||||
|
||||
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
|
||||
self.is_token_ids[[i1, i2], :max_active_token_count] = self.is_token_ids[
|
||||
[i2, i1], :max_active_token_count
|
||||
]
|
||||
|
||||
# Swap prompt embeddings if they exist
|
||||
embeds_i1 = self.req_prompt_embeds.get(i1)
|
||||
@@ -629,6 +638,11 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1],
|
||||
)
|
||||
|
||||
def _get_active_token_count(self, req_index: int) -> int:
|
||||
return int(self.num_tokens_no_spec[req_index]) + len(
|
||||
self.spec_token_ids[req_index]
|
||||
)
|
||||
|
||||
def condense(self) -> None:
|
||||
"""Slide non-empty requests down into lower, empty indices.
|
||||
|
||||
@@ -678,9 +692,7 @@ class InputBatch:
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
num_tokens = self.num_tokens_no_spec[last_req_index] + len(
|
||||
self.spec_token_ids[last_req_index]
|
||||
)
|
||||
num_tokens = self._get_active_token_count(last_req_index)
|
||||
|
||||
(self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
|
||||
self.spec_token_ids[empty_index],
|
||||
|
||||
Reference in New Issue
Block a user