[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend (#14238)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
@@ -39,6 +39,7 @@ from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from .utils import sanity_check_mm_encoder_outputs

@@ -90,7 +91,7 @@ MIN_NUM_SEQS = 8
 # The dummy_run should be comprehensive, ensuring all potential input shapes and
 # branch predictions are included as subgraph inputs to facilitate
 # pre-compilation.
-class TPUModelRunner:
+class TPUModelRunner(LoRAModelRunnerMixin):

     def __init__(
         self,
@@ -568,6 +569,17 @@ class TPUModelRunner:
             self.device)
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

+        if self.lora_config is not None:
+            # We need to respect padding when activating LoRA adapters
+            padded_num_scheduled_tokens_per_req = np.copy(
+                num_scheduled_tokens_per_req
+            )  # Copying to avoid accidental state corruption bugs
+            padded_num_scheduled_tokens_per_req[-1] += \
+                padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+            self.set_active_loras(self.input_batch,
+                                  padded_num_scheduled_tokens_per_req)
+
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
@@ -907,6 +919,11 @@ class TPUModelRunner:
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
             model = get_model(vllm_config=self.vllm_config)
+        if self.lora_config is not None:
+            model = self.load_lora_model(model, self.model_config,
+                                         self.scheduler_config,
+                                         self.lora_config, self.device)
+
         # Sync all pending XLA execution during model initialization and weight
         # loading.
         xm.mark_step()
@@ -970,7 +987,10 @@ class TPUModelRunner:
             for layer_name in layer_names
         }

-        with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0):
+        with self.maybe_dummy_run_with_lora(
+                self.lora_config,
+                np.array([num_tokens], dtype=np.int32)), set_forward_context(
+                    per_layer_attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
                              inputs_embeds=inputs_embeds)

@@ -15,6 +15,7 @@ from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -82,6 +83,10 @@ class TPUWorker:
         if self.model_config.seed is None:
             self.model_config.seed = 0

+        if vllm_config.lora_config is not None:
+            raise NotImplementedError(
+                "The V1 TPU backend doesn't support LoRA serving")
+
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@@ -211,6 +216,9 @@ class TPUWorker:
         else:
             xp.stop_trace()

+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
     def load_model(self) -> None:
         self.model_runner.load_model()
Reference in New Issue
Block a user