[Misc] LoRA + Chunked Prefill (#9057)

This commit is contained in:
Aurick Qiao
2024-12-10 21:09:20 -05:00
committed by GitHub
parent 9a93973708
commit d5c5154fcf
12 changed files with 49 additions and 20 deletions

View File

@@ -1698,7 +1698,8 @@ class LoRAConfig:
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if scheduler_config.chunked_prefill_enabled:
raise ValueError("LoRA is not supported with chunked prefill yet.")
logger.warning("LoRA with chunked prefill is still experimental "
"and may be unstable.")
@dataclass

View File

@@ -166,9 +166,18 @@ class SchedulerOutputs:
and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
def key_fn(group: ScheduledSequenceGroup):
key = (group.seq_group.lora_int_id, group.seq_group.request_id)
if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
# Sort sequence groups so that all prefills come before all
# decodes as required by chunked prefill.
return (not group.seq_group.is_prefill(), *key)
return key
self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
key=key_fn)
@property
def lora_requests(self) -> Set[LoRARequest]:

View File

@@ -622,11 +622,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
inter_data.lora_prompt_mapping.append(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs is not None
else 1))
sampling_params = seq_group_metadata.sampling_params
if sampling_params and sampling_params.prompt_logprobs is not None:
inter_data.lora_prompt_mapping.append([lora_id] * query_len)
elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
inter_data.lora_prompt_mapping.append([lora_id])
else:
inter_data.lora_prompt_mapping.append([])
def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,