[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend (#15655)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: xihajun <junfan@krai.ai>
Signed-off-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Signed-off-by: Jorge de Freitas <jorge@krai.ai>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: xihajun <junfan@krai.ai>
Co-authored-by: Jorge de Freitas <jorge.de-freitas22@imperial.ac.uk>
Co-authored-by: Jorge de Freitas <jorge@krai.ai>
@@ -80,8 +80,38 @@ class LoRAModelRunnerMixin:
                                      lora_requests)

    @contextmanager
    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
                                  num_scheduled_tokens: np.ndarray):
    def maybe_setup_dummy_loras(self, lora_config):
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_loras = lora_config.max_loras

            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path")
                for lora_id in range(1, num_loras + 1)
            }

            with self.lora_manager.dummy_lora_cache():
                # Add the dummy LoRAs here so _set_active_loras doesn't try to
                # load from disk.
                for lr in lora_requests:
                    self.lora_manager.add_dummy_lora(
                        lr, rank=self.LORA_WARMUP_RANK)

                yield

            # __exit__ code
            self.lora_manager.remove_all_adapters()

    @contextmanager
    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
                                 num_scheduled_tokens: np.ndarray):
        if lora_config is None:
            yield
        else:
@@ -108,21 +138,18 @@ class LoRAModelRunnerMixin:
                for lora_id in range(1, num_loras + 1)
            }

            with self.lora_manager.dummy_lora_cache():
                # Add the dummy LoRAs here so _set_active_loras doesn't try to
                # load from disk.
                for lr in lora_requests:
                    self.lora_manager.add_dummy_lora(
                        lr, rank=self.LORA_WARMUP_RANK)
            self._set_active_loras(tuple(prompt_lora_mapping),
                                   tuple(token_lora_mapping), lora_requests)

                self._set_active_loras(tuple(prompt_lora_mapping),
                                       tuple(token_lora_mapping),
                                       lora_requests)
                yield

            yield

            # __exit__ code
            self.lora_manager.remove_all_adapters()
    @contextmanager
    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
                                  num_scheduled_tokens: np.ndarray):
        with self.maybe_setup_dummy_loras(
                lora_config), self.maybe_select_dummy_loras(
                    lora_config, num_scheduled_tokens):
            yield

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if not self.lora_manager:
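Note: the old single warm-up context manager is now split in two. maybe_setup_dummy_loras registers the dummy adapters (and removes them on exit), while maybe_select_dummy_loras only activates a dummy LoRA mapping for the tokens of one dummy run; maybe_dummy_run_with_lora simply stacks the two. A minimal self-contained sketch of that composition pattern, using toy classes rather than the vLLM types:

from contextlib import contextmanager

class ToyLoRAMixin:
    """Toy stand-in for LoRAModelRunnerMixin to illustrate the split."""

    @contextmanager
    def maybe_setup_dummy_loras(self, lora_config):
        if lora_config is None:
            yield                            # no-op when LoRA is disabled
        else:
            print("add dummy adapters")      # __enter__: register adapters once
            yield
            print("remove all adapters")     # __exit__: clean up once

    @contextmanager
    def maybe_select_dummy_loras(self, lora_config, num_scheduled_tokens):
        if lora_config is None:
            yield
        else:
            print(f"activate dummy mapping for {num_scheduled_tokens} tokens")
            yield

    @contextmanager
    def maybe_dummy_run_with_lora(self, lora_config, num_scheduled_tokens):
        # Same composition as the diff: set up once, then select for this run.
        with self.maybe_setup_dummy_loras(
                lora_config), self.maybe_select_dummy_loras(
                    lora_config, num_scheduled_tokens):
            yield

runner = ToyLoRAMixin()
with runner.maybe_dummy_run_with_lora(lora_config="dummy",
                                      num_scheduled_tokens=[16]):
    print("dummy forward pass")

This split is what lets the later hunks wrap the whole precompilation loop in maybe_setup_dummy_loras once, while entering maybe_select_dummy_loras separately for every padded shape.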
@@ -20,6 +20,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
@@ -152,6 +153,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.hidden_size = model_config.get_hidden_size()
        self.vocab_size = model_config.get_vocab_size()

        if self.lora_config is not None:
            self.vocab_size += self.lora_config.lora_extra_vocab_size

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
@@ -591,6 +595,17 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)

        if self.lora_config is not None:
            # We need to respect padding when activating LoRA adapters
            padded_num_scheduled_tokens_per_req = np.copy(
                num_scheduled_tokens_per_req
            )  # Copying to avoid accidental state corruption bugs
            padded_num_scheduled_tokens_per_req[-1] += \
                padded_total_num_scheduled_tokens - total_num_scheduled_tokens

            self.set_active_loras(self.input_batch,
                                  padded_num_scheduled_tokens_per_req)

        layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                  Attention).keys()
        per_layer_attn_metadata = {
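The branch above reconciles the padded token count with the real one: the extra tokens added to reach the compiled bucket size are credited to the last request, so the per-token LoRA mapping covers every padded slot. A small self-contained sketch of that arithmetic with made-up numbers:

import numpy as np

# Hypothetical example: three requests scheduled for 7, 5 and 4 tokens,
# padded up to a compiled bucket of 32 tokens in total.
num_scheduled_tokens_per_req = np.array([7, 5, 4], dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens_per_req.sum())   # 16
padded_total_num_scheduled_tokens = 32

padded = np.copy(num_scheduled_tokens_per_req)  # avoid mutating scheduler state
padded[-1] += padded_total_num_scheduled_tokens - total_num_scheduled_tokens

print(padded)   # [ 7  5 20] -> the last request "absorbs" the 16 padding tokens
assert padded.sum() == padded_total_num_scheduled_tokens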
@@ -916,6 +931,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            model = self.load_lora_model(model, self.model_config,
                                         self.scheduler_config,
                                         self.lora_config, self.device)
            replace_set_lora(model)

        # Sync all pending XLA execution during model initialization and weight
        # loading.
@@ -980,7 +996,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                for layer_name in layer_names
            }

        with self.maybe_dummy_run_with_lora(
        with self.maybe_select_dummy_loras(
                self.lora_config,
                np.array([num_tokens], dtype=np.int32)), set_forward_context(
                    per_layer_attn_metadata, self.vllm_config, 0):
@@ -989,6 +1005,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                         inputs_embeds=inputs_embeds)
        self._hidden_states_dtype = out.dtype

    def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                          lora_requests) -> None:
        xm.mark_step()  # Captures input updates
        super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                  lora_requests)
        xm.mark_step()  # Captures metadata updates

    def _precompile_mm_encoder(self) -> None:
        # Pre-compile MM encoder for all supported data modalities.
        hf_config = self.vllm_config.model_config.hf_config
@@ -1151,7 +1174,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                    generate_params_if_all_greedy,
                ))
            sampling_metadata.all_greedy = all_greedy
            self.sample_from_logits(dummy_logits, sampling_metadata)
            with self.maybe_select_dummy_loras(
                    self.lora_config, np.array([num_reqs],
                                                dtype=np.int32)):
                self.sample_from_logits(dummy_logits, sampling_metadata)
            logger.info(" -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
@@ -1167,7 +1193,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                       dtype=self._hidden_states_dtype)
            dummy_tokens = torch.zeros((num_reqs, 1),
                                       dtype=torch.int64).to(self.device)
            self.gather_logprobs(dummy_logits, dummy_tokens)
            with self.maybe_select_dummy_loras(
                    self.lora_config, np.array([num_reqs], dtype=np.int32)):
                self.gather_logprobs(dummy_logits, dummy_tokens)
            logger.info(" -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
@@ -1178,13 +1206,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        """
        Precompile all the subgraphs with possible input shapes.
        """
        self._precompile_mm_encoder()
        self._precompile_backbone()
        self._precompile_select_hidden_states()
        self._precompile_compute_logits()
        self._precompile_structured_decoding()
        self._precompile_sample_from_logits()
        self._precompile_gather_logprobs()
        with self.maybe_setup_dummy_loras(self.lora_config):
            self._precompile_mm_encoder()
            self._precompile_backbone()
            self._precompile_select_hidden_states()
            self._precompile_compute_logits()
            self._precompile_structured_decoding()
            self._precompile_sample_from_logits()
            self._precompile_gather_logprobs()

    def profile_run(
        self,
@@ -1467,11 +1496,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
                        padding_gap: int) -> list[int]:
    """Generate a list of padding size, starting from min_token_size,
    ending with a number that can cover max_token_size


    If padding_gap == 0 then:
        increase 2X each time (exponential)
    else:
        first increase the size to twice,
        first increase the size to twice,
        then increase the padding size by padding_gap.
    """
    # assert min_token_size is power of 2
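For reference, a rough self-contained sketch of the padding schedule this docstring describes (an illustrative reimplementation, not the vLLM helper itself; the exact switch-over point from doubling to linear steps is an assumption):

def token_paddings_sketch(min_token_size: int, max_token_size: int,
                          padding_gap: int) -> list[int]:
    """Illustrative reimplementation of the documented padding schedule."""
    assert (min_token_size & (min_token_size - 1)) == 0, "expects a power of 2"
    paddings, size = [], min_token_size
    if padding_gap == 0:
        # Exponential: double until max_token_size is covered.
        while size < max_token_size:
            paddings.append(size)
            size *= 2
        paddings.append(size)
    else:
        # Double first, then step linearly by padding_gap (assumed switch-over).
        while size < padding_gap:
            paddings.append(size)
            size *= 2
        while size < max_token_size + padding_gap:
            paddings.append(size)
            size += padding_gap
    return paddings

print(token_paddings_sketch(16, 300, 0))    # [16, 32, 64, 128, 256, 512]
print(token_paddings_sketch(16, 300, 64))   # [16, 32, 64, 128, 192, 256, 320]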
@@ -1508,3 +1537,32 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]


def replace_set_lora(model):

    def _tpu_set_lora(
            self,
            index: int,
            lora_a: torch.Tensor,
            lora_b: torch.Tensor,
            embeddings_tensor: Optional[torch.Tensor],
            bias: Optional[torch.Tensor] = None,
    ):
        # TODO: The integer index leads to a recompilation, but converting it
        # to a tensor doesn't seem to work anymore. This might be fixed with a
        # later release of torch_xla.
        self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
        xm.mark_step()

    def _tpu_reset_lora(self, index: int):
        self._original_reset_lora(index)
        xm.mark_step()

    for _, module in model.named_modules():
        if isinstance(module, BaseLayerWithLoRA):
            module._original_set_lora = module.set_lora
            module._original_reset_lora = module.reset_lora
            module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
            module.reset_lora = _tpu_reset_lora.__get__(
                module, module.__class__)
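replace_set_lora patches each LoRA layer instance so every adapter update is bracketed by xm.mark_step(), cutting the lazy XLA graph before and after the in-place weight write. The per-instance rebinding uses the descriptor protocol: func.__get__(obj, cls) turns a plain function into a method bound to that one object. A small self-contained sketch of the rebinding trick, with prints standing in for the XLA steps:

class ToyLoRALayer:
    def set_lora(self, index, weight):
        print(f"set_lora(index={index}, weight={weight})")

def _traced_set_lora(self, index, weight):
    print("mark_step()  # flush pending ops before the update")
    self._original_set_lora(index, weight)   # call the saved bound original
    print("mark_step()  # compile/execute the update as its own step")

layer = ToyLoRALayer()
layer._original_set_lora = layer.set_lora             # keep the original bound method
layer.set_lora = _traced_set_lora.__get__(layer, ToyLoRALayer)  # bind to this instance

layer.set_lora(0, "lora_A/lora_B")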
@@ -83,10 +83,6 @@ class TPUWorker:
            if self.model_config.seed is None:
                self.model_config.seed = 0

        if vllm_config.lora_config is not None:
            raise NotImplementedError(
                "The V1 TPU backend doesn't support LoRA serving")

    def init_device(self):
        os.environ["PJRT_DEVICE"] = "TPU"
        # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@@ -166,7 +162,8 @@ class TPUWorker:
                                                runner_kv_caches)

        # `max_num_tokens >= max_num_batched_tokens` due to padding.
        self.model_runner.profile_run(self.model_runner.max_num_tokens)
        with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
            self.model_runner.profile_run(self.model_runner.max_num_tokens)

        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()
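With the hard error removed from TPUWorker, LoRA adapters can be served on the V1 TPU backend through the usual vLLM multi-LoRA API. A minimal usage sketch (the model name, adapter path and limits below are placeholders, not values taken from this PR):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Hypothetical base model and adapter; any LoRA adapter trained for the base model works.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          enable_lora=True,
          max_loras=4,
          max_lora_rank=8)

outputs = llm.generate(
    ["Explain multi-LoRA serving in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest(lora_name="my_adapter",
                             lora_int_id=1,
                             lora_path="/path/to/adapter"),
)
print(outputs[0].outputs[0].text)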