[Core] Add Lora Support to Beam Search (#18346)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
Alex Brooks
2025-05-28 09:58:24 -06:00
committed by GitHub
parent 6e4cea1cc5
commit 321331b8ae
7 changed files with 150 additions and 16 deletions

View File

@@ -65,6 +65,7 @@ class EngineClient(ABC):
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
lora_request: Optional[LoRARequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
@@ -106,27 +107,31 @@ class EngineClient(ABC):
cum_logprob=0,
logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request)
]
completed = []
for _ in range(max_tokens):
prompts_batch = [
prompts_batch, lora_req_batch = zip(*[(
TokensPrompt(prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs)
for beam in all_beams
]
mm_processor_kwargs=beam.mm_processor_kwargs),
beam.lora_request,
) for beam in all_beams])
tasks = []
request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
for i, (individual_prompt,
lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
self.generate(individual_prompt,
beam_search_params,
request_id_item,
lora_request=lora_req)))
tasks.append(task)
output = await asyncio.gather(*tasks)
@@ -159,6 +164,7 @@ class EngineClient(ABC):
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs +
[logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.