[core] add bucket padding to tpu_model_runner (#14995)
Signed-off-by: Chenyaaang <llccyy1212@gmail.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import bisect
 import time
 from typing import TYPE_CHECKING, Optional, cast
 from unittest.mock import patch
@@ -170,6 +171,10 @@ class TPUModelRunner:
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
+        self.num_tokens_paddings = _get_paddings(
+            min_token_size=16,
+            max_token_size=self.max_num_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the scheduler
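Note: the padding buckets are computed once at initialization. As a rough illustration (the sizes below are hypothetical, assuming VLLM_TPU_BUCKET_PADDING_GAP=128 and max_num_tokens=1024, neither of which is fixed by this patch), the resulting list doubles up to the gap and then steps linearly by it:

    # Hypothetical config, not from this patch:
    #   min_token_size=16, max_token_size=1024, padding_gap=128
    num_tokens_paddings = [16, 32, 64, 128, 256, 384, 512, 640, 768, 896, 1024]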
@@ -422,7 +427,7 @@ class TPUModelRunner:
 
         # Do the padding and copy the tensors to the TPU.
         padded_total_num_scheduled_tokens = _get_padded_token_len(
-            total_num_scheduled_tokens)
+            self.num_tokens_paddings, total_num_scheduled_tokens)
         # Zero out to avoid spurious values from prev iteration (last cp chunk)
         self.input_ids_cpu[
             total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
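The lookup-and-zero step above can be sketched in isolation. A minimal sketch, assuming a small bucket list and a NumPy CPU buffer; the names and sizes here are illustrative stand-ins, not the runner's real state:

    import numpy as np
    from bisect import bisect_left

    paddings = [16, 32, 64, 128, 256]                 # assumed bucket list
    total = 50                                        # tokens actually scheduled
    padded = paddings[bisect_left(paddings, total)]   # first bucket >= 50 -> 64

    input_ids_cpu = np.ones(256, dtype=np.int32)      # stale data from last step
    input_ids_cpu[total:padded] = 0                   # zero only the padded tail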
@@ -573,7 +578,6 @@ class TPUModelRunner:
 
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
 
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -764,26 +768,21 @@ class TPUModelRunner:
         logger.info("Compiling the model with different input shapes.")
 
         start = time.perf_counter()
-        num_tokens = 16
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
             self._dummy_run(self.kv_caches, num_tokens)
             xm.mark_step()
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
 
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
-        num_tokens = 16
         hsize = self.model_config.get_hidden_size()
         device = self.device
         # Compile sampling step for different model+sampler outputs in bucketed
         # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             num_reqs_to_sample = MIN_NUM_SEQS
             dummy_hidden = torch.randn((num_tokens, hsize),
                                        device=device,
@@ -805,9 +804,6 @@ class TPUModelRunner:
                 if num_reqs_to_sample >= self.max_num_reqs:
                     break
                 num_reqs_to_sample *= 2
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
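With bucketed paddings, the number of precompiled graphs is driven by len(self.num_tokens_paddings) rather than by doubling from 16 up to max_num_tokens. A back-of-the-envelope sketch, under assumed sizes (none of these numbers come from this patch):

    import math

    num_token_buckets = 11             # e.g. len(num_tokens_paddings) above
    MIN_NUM_SEQS, max_num_reqs = 8, 256
    # request buckets still double: 8, 16, 32, 64, 128, 256
    num_req_buckets = int(math.log2(max_num_reqs // MIN_NUM_SEQS)) + 1  # 6

    print(num_token_buckets)                    # 11 model-forward graphs
    print(num_token_buckets * num_req_buckets)  # 66 sampling graphs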
@@ -939,12 +935,33 @@ def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
 
 
-def _get_padded_token_len(x: int) -> int:
-    if x <= 16:
-        return 16
-    return 1 << (x - 1).bit_length()
-
-
 def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)
+
+
+def _get_paddings(min_token_size: int, max_token_size: int,
+                  padding_gap: int) -> list[int]:
+    """Generate a list of padding sizes, starting from min_token_size
+    and ending with a value that covers max_token_size:
+    sizes first double from min_token_size up to padding_gap,
+    then grow linearly by padding_gap per step.
+    """
+    paddings = []
+    num = min_token_size
+    while num <= padding_gap:
+        paddings.append(num)
+        num *= 2
+    num //= 2
+    while num < max_token_size:
+        num += padding_gap
+        paddings.append(num)
+    return paddings
+
+
+def _get_padded_token_len(paddings: list[int], x: int) -> int:
+    """Return the first element in paddings that is greater than or equal to x.
+    """
+    index = bisect.bisect_left(paddings, x)
+    assert index < len(paddings)
+    return paddings[index]
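The two new helpers are easy to check in isolation. The bodies below are adapted from the diff (annotations dropped) so the snippet runs standalone; the sizes passed in are hypothetical:

    import bisect

    def _get_paddings(min_token_size, max_token_size, padding_gap):
        paddings, num = [], min_token_size
        while num <= padding_gap:        # double up to the gap
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:      # then step linearly by the gap
            num += padding_gap
            paddings.append(num)
        return paddings

    def _get_padded_token_len(paddings, x):
        index = bisect.bisect_left(paddings, x)  # first bucket >= x
        assert index < len(paddings)
        return paddings[index]

    paddings = _get_paddings(16, 1024, 128)
    print(paddings)        # [16, 32, 64, 128, 256, 384, 512, 640, 768, 896, 1024]
    print(_get_padded_token_len(paddings, 1))     # 16
    print(_get_padded_token_len(paddings, 300))   # 384
    print(_get_padded_token_len(paddings, 1024))  # 1024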