[Core] Add Lora Support to Beam Search (#18346)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -65,6 +65,7 @@ class EngineClient(ABC):
|
||||
prompt: PromptType,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
|
||||
beam_width = params.beam_width
|
||||
@@ -106,27 +107,31 @@ class EngineClient(ABC):
|
||||
cum_logprob=0,
|
||||
logprobs=[],
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs)
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
lora_request=lora_request)
|
||||
]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
prompts_batch = [
|
||||
prompts_batch, lora_req_batch = zip(*[(
|
||||
TokensPrompt(prompt_token_ids=beam.tokens,
|
||||
multi_modal_data=beam.multi_modal_data,
|
||||
mm_processor_kwargs=beam.mm_processor_kwargs)
|
||||
for beam in all_beams
|
||||
]
|
||||
mm_processor_kwargs=beam.mm_processor_kwargs),
|
||||
beam.lora_request,
|
||||
) for beam in all_beams])
|
||||
|
||||
tasks = []
|
||||
|
||||
request_id = f"beam_search-{random_uuid()}"
|
||||
for i, individual_prompt in enumerate(prompts_batch):
|
||||
for i, (individual_prompt,
|
||||
lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
task = asyncio.create_task(
|
||||
collect_from_async_generator(
|
||||
self.generate(individual_prompt, beam_search_params,
|
||||
request_id_item)))
|
||||
self.generate(individual_prompt,
|
||||
beam_search_params,
|
||||
request_id_item,
|
||||
lora_request=lora_req)))
|
||||
tasks.append(task)
|
||||
|
||||
output = await asyncio.gather(*tasks)
|
||||
@@ -159,6 +164,7 @@ class EngineClient(ABC):
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs +
|
||||
[logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob +
|
||||
logprob_obj.logprob,
|
||||
multi_modal_data=current_beam.
|
||||
|
||||
Reference in New Issue
Block a user