[core] add bucket padding to tpu_model_runner (#14995)

Signed-off-by: Chenyaaang <llccyy1212@gmail.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
Chenyaaang
2025-03-25 14:27:22 -07:00
committed by GitHub
parent 082ab86f5f
commit ac3cd6e83c
3 changed files with 63 additions and 19 deletions

View File

@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import bisect
import time
from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch
@@ -170,6 +171,10 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_tokens_paddings = _get_paddings(
min_token_size=16,
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
"""Update the cached states and the persistent batch with the scheduler
@@ -422,7 +427,7 @@ class TPUModelRunner:
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens = _get_padded_token_len(
total_num_scheduled_tokens)
self.num_tokens_paddings, total_num_scheduled_tokens)
# Zero out to avoid spurious values from prev iteration (last cp chunk)
self.input_ids_cpu[
total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
@@ -573,7 +578,6 @@ class TPUModelRunner:
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
@@ -764,26 +768,21 @@ class TPUModelRunner:
logger.info("Compiling the model with different input shapes.")
start = time.perf_counter()
num_tokens = 16
while True:
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step()
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
num_tokens = 16
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
while True:
for num_tokens in self.num_tokens_paddings:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
@@ -805,9 +804,6 @@ class TPUModelRunner:
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -939,12 +935,33 @@ def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
def _get_padded_token_len(x: int) -> int:
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
    """Pad ``x`` up to the next power of two, floored at MIN_NUM_SEQS and
    capped at ``upper_limit``.
    """
    if x <= MIN_NUM_SEQS:
        padded = MIN_NUM_SEQS
    else:
        padded = 1 << (x - 1).bit_length()
    return min(padded, upper_limit)
def _get_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
first increase the size to twice,
then increase the padding size by padding_gap.
"""
paddings = []
num = min_token_size
while num <= padding_gap:
paddings.append(num)
num *= 2
num //= 2
while num < max_token_size:
num += padding_gap
paddings.append(num)
return paddings
def _get_padded_token_len(paddings: list[int], x: int) -> int:
"""Return the first element in paddings list greater or equal to x.
"""
index = bisect.bisect_left(paddings, x)
assert index < len(paddings)
return paddings[index]