[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend (#14238)

Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Author: Akshat Tripathi
Date: 2025-05-07 21:28:47 +01:00 (committed by GitHub)
Parent: db593aa67f
Commit: c20ef40fd0
19 changed files with 929 additions and 46 deletions

vllm/v1/worker/tpu_model_runner.py

@@ -39,6 +39,7 @@ from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from .utils import sanity_check_mm_encoder_outputs
@@ -90,7 +91,7 @@ MIN_NUM_SEQS = 8
 # The dummy_run should be comprehensive, ensuring all potential input shapes and
 # branch predictions are included as subgraph inputs to facilitate
 # pre-compilation.
-class TPUModelRunner:
+class TPUModelRunner(LoRAModelRunnerMixin):
 
     def __init__(
         self,
@@ -568,6 +569,17 @@ class TPUModelRunner:
                 self.device)
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
 
+        if self.lora_config is not None:
+            # We need to respect padding when activating LoRA adapters
+            padded_num_scheduled_tokens_per_req = np.copy(
+                num_scheduled_tokens_per_req
+            )  # Copying to avoid accidental state corruption bugs
+            padded_num_scheduled_tokens_per_req[-1] += \
+                padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+            self.set_active_loras(self.input_batch,
+                                  padded_num_scheduled_tokens_per_req)
+
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
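
The padding arithmetic in this hunk is easiest to see with concrete numbers. A minimal, self-contained sketch follows; the token counts and bucket size are made up for illustration:

import numpy as np

# Hypothetical batch: 3 requests scheduled with 5, 3 and 2 tokens (10 total),
# padded up to a 16-token bucket so XLA sees a pre-compiled input shape.
num_scheduled_tokens_per_req = np.array([5, 3, 2], dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens_per_req.sum())  # 10
padded_total_num_scheduled_tokens = 16                                # compiled bucket

# Same trick as the diff: copy first, then let the last request absorb
# the padding so the per-request counts sum to the padded total.
padded = np.copy(num_scheduled_tokens_per_req)
padded[-1] += padded_total_num_scheduled_tokens - total_num_scheduled_tokens
assert padded.sum() == padded_total_num_scheduled_tokens              # [5, 3, 8]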
@@ -907,6 +919,11 @@ class TPUModelRunner:
                         "get_tensor_model_parallel_rank",
                         return_value=xm_tp_rank):
                     model = get_model(vllm_config=self.vllm_config)
+
+        if self.lora_config is not None:
+            model = self.load_lora_model(model, self.model_config,
+                                         self.scheduler_config,
+                                         self.lora_config, self.device)
         # Sync all pending XLA execution during model initialization and weight
         # loading.
         xm.mark_step()
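
The mark_step() call matters because torch_xla executes lazily: operations are only traced until the step is marked. A small illustration of that pattern, using standard torch_xla API rather than anything vLLM-specific:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(1024, 1024, device=device)  # traced, not yet executed
y = w @ w                                   # still only traced
xm.mark_step()                              # flush all pending XLA execution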
@@ -970,7 +987,10 @@ class TPUModelRunner:
             for layer_name in layer_names
         }
 
-        with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0):
+        with self.maybe_dummy_run_with_lora(
+                self.lora_config,
+                np.array([num_tokens], dtype=np.int32)), set_forward_context(
+                    per_layer_attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
                              inputs_embeds=inputs_embeds)
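
maybe_dummy_run_with_lora comes from the mixin and degrades to a no-op when LoRA is disabled, which is what lets the caller keep a single with statement. A minimal sketch of that shape, with illustrative stubs rather than vLLM's real internals:

from contextlib import contextmanager

def _activate_dummy_loras(lora_config, num_scheduled_tokens):
    ...  # illustrative stub: assign placeholder adapters to every token

def _deactivate_loras():
    ...  # illustrative stub: restore the no-adapter state

@contextmanager
def maybe_dummy_run_with_lora(lora_config, num_scheduled_tokens):
    if lora_config is None:
        yield  # LoRA disabled: behave like contextlib.nullcontext()
        return
    # With LoRA enabled, placeholder adapters stay active during the dummy
    # run so the compiled graph includes the LoRA matmul shapes.
    _activate_dummy_loras(lora_config, num_scheduled_tokens)
    try:
        yield
    finally:
        _deactivate_loras()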

vllm/v1/worker/tpu_worker.py

@@ -15,6 +15,7 @@ from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -82,6 +83,10 @@ class TPUWorker:
         if self.model_config.seed is None:
             self.model_config.seed = 0
 
+        if vllm_config.lora_config is not None:
+            raise NotImplementedError(
+                "The V1 TPU backend doesn't support LoRA serving")
+
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@@ -211,6 +216,9 @@ class TPUWorker:
         else:
             xp.stop_trace()
 
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
     def load_model(self) -> None:
         self.model_runner.load_model()
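
add_lora is the worker-side hook the engine invokes when a request arrives with an adapter attached. For reference, this is how multi-LoRA is exercised through vLLM's public API; the model name and adapter path are placeholders, and whether TPU serving is enabled end to end depends on the worker guard above:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf",  # placeholder model
          enable_lora=True,
          max_loras=4,
          max_lora_rank=8)

outputs = llm.generate(
    ["Write a haiku about TPUs."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # adapter name, globally unique int id, local path to adapter weights
    lora_request=LoRARequest("my-adapter", 1, "/path/to/adapter"),
)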