[Misc] LoRA + Chunked Prefill (#9057)
This commit is contained in:
@@ -1698,7 +1698,8 @@ class LoRAConfig:
|
||||
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
|
||||
# If the feature combo becomes valid
|
||||
if scheduler_config.chunked_prefill_enabled:
|
||||
raise ValueError("LoRA is not supported with chunked prefill yet.")
|
||||
logger.warning("LoRA with chunked prefill is still experimental "
|
||||
"and may be unstable.")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -166,9 +166,18 @@ class SchedulerOutputs:
|
||||
and not self.blocks_to_swap_out and not self.blocks_to_copy)
|
||||
|
||||
def _sort_by_lora_ids(self):
|
||||
self.scheduled_seq_groups = sorted(
|
||||
self.scheduled_seq_groups,
|
||||
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
|
||||
assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
|
||||
|
||||
def key_fn(group: ScheduledSequenceGroup):
|
||||
key = (group.seq_group.lora_int_id, group.seq_group.request_id)
|
||||
if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
|
||||
# Sort sequence groups so that all prefills come before all
|
||||
# decodes as required by chunked prefill.
|
||||
return (not group.seq_group.is_prefill(), *key)
|
||||
return key
|
||||
|
||||
self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
|
||||
key=key_fn)
|
||||
|
||||
@property
|
||||
def lora_requests(self) -> Set[LoRARequest]:
|
||||
|
||||
@@ -622,11 +622,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
inter_data.lora_requests.add(seq_group_metadata.lora_request)
|
||||
query_len = inter_data.query_lens[seq_idx]
|
||||
inter_data.lora_index_mapping.append([lora_id] * query_len)
|
||||
inter_data.lora_prompt_mapping.append(
|
||||
[lora_id] *
|
||||
(query_len if seq_group_metadata.sampling_params
|
||||
and seq_group_metadata.sampling_params.prompt_logprobs is not None
|
||||
else 1))
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
if sampling_params and sampling_params.prompt_logprobs is not None:
|
||||
inter_data.lora_prompt_mapping.append([lora_id] * query_len)
|
||||
elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
|
||||
inter_data.lora_prompt_mapping.append([lora_id])
|
||||
else:
|
||||
inter_data.lora_prompt_mapping.append([])
|
||||
|
||||
def _compute_prompt_adapter_input(
|
||||
self, inter_data: InterDataForSeqGroup,
|
||||
|
||||
Reference in New Issue
Block a user