[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend (#15655)

Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: xihajun <junfan@krai.ai>
Signed-off-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Signed-off-by: Jorge de Freitas <jorge@krai.ai>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: xihajun <junfan@krai.ai>
Co-authored-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Co-authored-by: Jorge de Freitas <jorge@krai.ai>
Author: Akshat Tripathi
Date:   2025-05-28 20:59:09 +01:00
Committed by: GitHub
Commit: 643622ba46 (parent a09c7ca9f2)

9 changed files with 325 additions and 334 deletions


@@ -80,8 +80,38 @@ class LoRAModelRunnerMixin:
                                 lora_requests)
 
     @contextmanager
-    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
-                                  num_scheduled_tokens: np.ndarray):
+    def maybe_setup_dummy_loras(self, lora_config):
+        if lora_config is None:
+            yield
+        else:
+            # __enter__ code
+            assert self.lora_manager is not None, "LoRA is not enabled"
+
+            num_loras = lora_config.max_loras
+
+            # Make dummy lora requests
+            lora_requests: set[LoRARequest] = {
+                LoRARequest(lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_path="/not/a/real/path")
+                for lora_id in range(1, num_loras + 1)
+            }
+
+            with self.lora_manager.dummy_lora_cache():
+                # Add the dummy LoRAs here so _set_active_loras doesn't try to
+                # load from disk.
+                for lr in lora_requests:
+                    self.lora_manager.add_dummy_lora(
+                        lr, rank=self.LORA_WARMUP_RANK)
+
+                yield
+
+            # __exit__ code
+            self.lora_manager.remove_all_adapters()
+
+    @contextmanager
+    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
+                                 num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
         else:
@@ -108,21 +138,18 @@ class LoRAModelRunnerMixin:
                 for lora_id in range(1, num_loras + 1)
             }
 
-            with self.lora_manager.dummy_lora_cache():
-                # Add the dummy LoRAs here so _set_active_loras doesn't try to
-                # load from disk.
-                for lr in lora_requests:
-                    self.lora_manager.add_dummy_lora(
-                        lr, rank=self.LORA_WARMUP_RANK)
-
-                self._set_active_loras(tuple(prompt_lora_mapping),
-                                       tuple(token_lora_mapping), lora_requests)
-
-                yield
-
-            # __exit__ code
-            self.lora_manager.remove_all_adapters()
+            self._set_active_loras(tuple(prompt_lora_mapping),
+                                   tuple(token_lora_mapping),
+                                   lora_requests)
+
+            yield
+
+    @contextmanager
+    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
+                                  num_scheduled_tokens: np.ndarray):
+        with self.maybe_setup_dummy_loras(
+                lora_config), self.maybe_select_dummy_loras(
+                    lora_config, num_scheduled_tokens):
+            yield
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if not self.lora_manager:
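The point of the refactor in the two hunks above is that adapter registration (expensive: one `add_dummy_lora` per slot) is now a separate context manager from adapter selection (cheap: one `_set_active_loras` call), so callers can register once and select many times. A minimal sketch of how the split composes, with stand-in print calls instead of the real LoRA manager:

from contextlib import contextmanager

@contextmanager
def setup_once():
    print("register dummy adapters")    # expensive: runs one time
    yield
    print("remove all adapters")        # cleanup on exit

@contextmanager
def select_for(shape):
    print(f"activate adapters for shape {shape}")  # cheap: runs per shape
    yield

# Registration happens once; selection repeats for every precompiled shape.
with setup_once():
    for shape in (16, 32, 64):
        with select_for(shape):
            pass  # dummy-run / precompile this input shape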


@@ -20,6 +20,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
+from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
@@ -152,6 +153,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.hidden_size = model_config.get_hidden_size()
         self.vocab_size = model_config.get_vocab_size()
 
+        if self.lora_config is not None:
+            self.vocab_size += self.lora_config.lora_extra_vocab_size
+
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
@@ -591,6 +595,17 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
 
+        if self.lora_config is not None:
+            # We need to respect padding when activating LoRA adapters
+            padded_num_scheduled_tokens_per_req = np.copy(
+                num_scheduled_tokens_per_req
+            )  # Copying to avoid accidental state corruption bugs
+            padded_num_scheduled_tokens_per_req[-1] += \
+                padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+            self.set_active_loras(self.input_batch,
+                                  padded_num_scheduled_tokens_per_req)
+
         layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                   Attention).keys()
         per_layer_attn_metadata = {
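This padding arithmetic matters because the TPU runner rounds the scheduled token count up to a precompiled bucket; the per-request counts handed to `set_active_loras` must sum to that padded total, so the difference is folded into the last request. A small worked example with illustrative numbers (not from the commit):

import numpy as np

# Three requests scheduled for 7, 5 and 3 tokens; the bucket pads 15 -> 16.
num_scheduled_tokens_per_req = np.array([7, 5, 3], dtype=np.int32)
total = int(num_scheduled_tokens_per_req.sum())   # 15
padded_total = 16                                 # next precompiled bucket

padded = np.copy(num_scheduled_tokens_per_req)    # copy: don't mutate state
padded[-1] += padded_total - total                # last request absorbs the pad
assert padded.tolist() == [7, 5, 4] and padded.sum() == padded_total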
@@ -916,6 +931,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
                                          self.lora_config, self.device)
+            replace_set_lora(model)
 
         # Sync all pending XLA execution during model initialization and weight
         # loading.
@@ -980,7 +996,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             for layer_name in layer_names
         }
 
-        with self.maybe_dummy_run_with_lora(
+        with self.maybe_select_dummy_loras(
                 self.lora_config,
                 np.array([num_tokens], dtype=np.int32)), set_forward_context(
                     per_layer_attn_metadata, self.vllm_config, 0):
@@ -989,6 +1005,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                               inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype
 
+    def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
+                          lora_requests) -> None:
+        xm.mark_step()  # Captures input updates
+        super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                  lora_requests)
+        xm.mark_step()  # Captures metadata updates
+
     def _precompile_mm_encoder(self) -> None:
         # Pre-compile MM encoder for all supported data modalities.
         hf_config = self.vllm_config.model_config.hf_config
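The `_set_active_loras` override above brackets the update with `xm.mark_step()`. Under torch_xla's lazy tracing, `mark_step` cuts the pending operations into a graph and dispatches it, so the input updates and the LoRA metadata writes each compile (and cache) as their own small graphs instead of being fused into, and forcing a recompile of, the main model graph. A rough sketch of the pattern, with `update_metadata` as a hypothetical stand-in for the metadata writes:

import torch_xla.core.xla_model as xm

def bracketed_update(update_metadata):
    xm.mark_step()     # flush the ops traced so far (the input updates)
    update_metadata()  # traced in isolation, so its graph is reusable
    xm.mark_step()     # materialise the metadata graph before the forward pass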
@@ -1151,7 +1174,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                     generate_params_if_all_greedy,
                 ))
             sampling_metadata.all_greedy = all_greedy
-            self.sample_from_logits(dummy_logits, sampling_metadata)
+            with self.maybe_select_dummy_loras(
+                    self.lora_config, np.array([num_reqs],
+                                               dtype=np.int32)):
+                self.sample_from_logits(dummy_logits, sampling_metadata)
             logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1167,7 +1193,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 dtype=self._hidden_states_dtype)
             dummy_tokens = torch.zeros((num_reqs, 1),
                                        dtype=torch.int64).to(self.device)
-            self.gather_logprobs(dummy_logits, dummy_tokens)
+            with self.maybe_select_dummy_loras(
+                    self.lora_config, np.array([num_reqs], dtype=np.int32)):
+                self.gather_logprobs(dummy_logits, dummy_tokens)
             logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1178,13 +1206,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         """
         Precompile all the subgraphs with possible input shapes.
         """
-        self._precompile_mm_encoder()
-        self._precompile_backbone()
-        self._precompile_select_hidden_states()
-        self._precompile_compute_logits()
-        self._precompile_structured_decoding()
-        self._precompile_sample_from_logits()
-        self._precompile_gather_logprobs()
+        with self.maybe_setup_dummy_loras(self.lora_config):
+            self._precompile_mm_encoder()
+            self._precompile_backbone()
+            self._precompile_select_hidden_states()
+            self._precompile_compute_logits()
+            self._precompile_structured_decoding()
+            self._precompile_sample_from_logits()
+            self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1467,11 +1496,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
                         padding_gap: int) -> list[int]:
     """Generate a list of padding size, starting from min_token_size,
     ending with a number that can cover max_token_size
 
     If padding_gap == 0 then:
         increase 2X each time (exponential)
     else:
-        first increase the size to twice, 
+        first increase the size to twice,
         then increase the padding size by padding_gap.
     """
     # assert min_token_size is power of 2
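For reference, here is what the docstring describes, as a hedged reimplementation (the real `_get_token_paddings` in this file may differ in cutoffs and logging): with `padding_gap == 0` the sizes double until `max_token_size` is covered; otherwise they double up to the gap and then step linearly by it.

def token_paddings(min_size: int, max_size: int, gap: int) -> list[int]:
    assert min_size & (min_size - 1) == 0, "min_size must be a power of 2"
    paddings, num = [], min_size
    if gap == 0:
        while True:                   # exponential: 2X each time
            paddings.append(num)
            if num >= max_size:
                return paddings
            num *= 2
    while num <= gap:                 # first double the size...
        paddings.append(num)
        num *= 2
    num //= 2
    while num < max_size:             # ...then step by padding_gap
        num += gap
        paddings.append(num)
    return paddings

print(token_paddings(16, 128, 0))     # [16, 32, 64, 128]
print(token_paddings(16, 128, 32))    # [16, 32, 64, 96, 128]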
@@ -1508,3 +1537,32 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     index = bisect.bisect_left(paddings, x)
     assert index < len(paddings)
     return paddings[index]
+
+
+def replace_set_lora(model):
+
+    def _tpu_set_lora(
+        self,
+        index: int,
+        lora_a: torch.Tensor,
+        lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
+    ):
+        # TODO: The integer index leads to a recompilation, but converting it
+        # to a tensor doesn't seem to work anymore. This might be fixed with a
+        # later release of torch_xla.
+        self._original_set_lora(index, lora_a, lora_b, embeddings_tensor,
+                                bias)
+        xm.mark_step()
+
+    def _tpu_reset_lora(self, index: int):
+        self._original_reset_lora(index)
+        xm.mark_step()
+
+    for _, module in model.named_modules():
+        if isinstance(module, BaseLayerWithLoRA):
+            module._original_set_lora = module.set_lora
+            module._original_reset_lora = module.reset_lora
+            module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
+            module.reset_lora = _tpu_reset_lora.__get__(
+                module, module.__class__)
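`_tpu_set_lora.__get__(module, module.__class__)` uses Python's descriptor protocol to bind a plain function as a method of one specific instance, so only the LoRA layers get the wrapped behaviour while their classes stay untouched. A standalone demo of the same patching pattern (`Widget` and `greet` are made-up names for illustration):

class Widget:
    def greet(self):
        return "hello"

def loud_greet(self):
    # Delegate to the saved bound method, then change the result.
    return self._original_greet().upper()

w = Widget()
w._original_greet = w.greet              # keep a reference to the original
w.greet = loud_greet.__get__(w, Widget)  # bind to this one instance
assert w.greet() == "HELLO"
assert Widget().greet() == "hello"       # other instances are unaffected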


@@ -83,10 +83,6 @@ class TPUWorker:
         if self.model_config.seed is None:
             self.model_config.seed = 0
 
-        if vllm_config.lora_config is not None:
-            raise NotImplementedError(
-                "The V1 TPU backend doesn't support LoRA serving")
-
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@@ -166,7 +162,8 @@ class TPUWorker:
                                  runner_kv_caches)
 
         # `max_num_tokens >= max_num_batched_tokens` due to padding.
-        self.model_runner.profile_run(self.model_runner.max_num_tokens)
+        with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
+            self.model_runner.profile_run(self.model_runner.max_num_tokens)
 
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
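Running `profile_run` inside `maybe_setup_dummy_loras` means the memory measurement taken after `xm.wait_device_ops()` is made with the full set of dummy adapters registered, so the profiling pass presumably accounts for LoRA overhead when sizing the remaining device memory.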