[Perf] [Hybrid] Copy num_accepted_tokens in a non-blocking way when not using prefix caching (#35442)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -1191,13 +1191,14 @@ class GPUModelRunner(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Find the number of accepted tokens for each sequence.
|
# Find the number of accepted tokens for each sequence.
|
||||||
num_accepted_tokens = (
|
num_reqs = output_token_ids.size(0)
|
||||||
|
self.num_accepted_tokens.gpu[:num_reqs] = (
|
||||||
(
|
(
|
||||||
torch.cat(
|
torch.cat(
|
||||||
[
|
[
|
||||||
output_token_ids,
|
output_token_ids,
|
||||||
torch.full(
|
torch.full(
|
||||||
(output_token_ids.size(0), 1),
|
(num_reqs, 1),
|
||||||
-1,
|
-1,
|
||||||
device=output_token_ids.device,
|
device=output_token_ids.device,
|
||||||
),
|
),
|
||||||
@@ -1208,12 +1209,13 @@ class GPUModelRunner(
|
|||||||
)
|
)
|
||||||
.int()
|
.int()
|
||||||
.argmax(-1)
|
.argmax(-1)
|
||||||
.cpu()
|
|
||||||
.numpy()
|
|
||||||
)
|
)
|
||||||
for i, num_tokens in enumerate(num_accepted_tokens):
|
|
||||||
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
|
||||||
if self.cache_config.mamba_cache_mode == "align":
|
if self.cache_config.mamba_cache_mode == "align":
|
||||||
|
for i, num_tokens in enumerate(
|
||||||
|
self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
|
||||||
|
):
|
||||||
|
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
||||||
|
|
||||||
mamba_utils.postprocess_mamba(
|
mamba_utils.postprocess_mamba(
|
||||||
scheduler_output,
|
scheduler_output,
|
||||||
self.kv_cache_config,
|
self.kv_cache_config,
|
||||||
@@ -1224,6 +1226,10 @@ class GPUModelRunner(
|
|||||||
self.model.get_mamba_state_copy_func(),
|
self.model.get_mamba_state_copy_func(),
|
||||||
self._get_mamba_copy_bufs(),
|
self._get_mamba_copy_bufs(),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
|
||||||
|
self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
def _update_streaming_request(
|
def _update_streaming_request(
|
||||||
self, req_id: str, new_req_data: NewRequestData
|
self, req_id: str, new_req_data: NewRequestData
|
||||||
|
|||||||
Reference in New Issue
Block a user