[Core][Bugfix] Fix Offline MM Beam Search (#16390)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
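This change threads multimodal inputs through offline beam search: BeamSearchInstance now accepts and propagates multi_modal_data and mm_processor_kwargs, LLM.beam_search() builds each step's TokensPrompt from the beam so that state is carried forward, and the old warning about multimodal data being dropped is removed.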
@@ -38,9 +38,18 @@ class BeamSearchOutput:
 class BeamSearchInstance:

-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
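Why the constructor copies logprobs instead of storing the caller's list: an aliased list shared across beams would let an append on one beam corrupt the others. The following is a minimal, self-contained sketch of that hazard; Seq and Instance are stand-ins for vLLM's BeamSearchSequence and BeamSearchInstance, not the real classes:

# Minimal sketch of the aliasing hazard, with stand-in classes (Seq and
# Instance are NOT vLLM's BeamSearchSequence/BeamSearchInstance).
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Seq:
    tokens: list[int]
    logprobs: list[dict] = field(default_factory=list)


class Instance:

    def __init__(self,
                 prompt_tokens: list[int],
                 logprobs: Optional[list[dict]] = None,
                 **kwargs):
        # list(...) copies the caller's list; storing it directly would let
        # later mutations leak into every beam sharing the reference.
        self.beams = [
            Seq(tokens=prompt_tokens,
                logprobs=[] if logprobs is None else list(logprobs),
                **kwargs)
        ]


shared: list[dict] = []
a, b = Instance([1], logprobs=shared), Instance([2], logprobs=shared)
shared.append({0: -0.1})          # the caller mutates its own list...
assert a.beams[0].logprobs == []  # ...but neither instance observes it
assert b.beams[0].logprobs == []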
@@ -536,15 +536,18 @@ class LLM:
                                          tokenizer.eos_token_id,
                                          length_penalty)

-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
+        def create_tokens_prompt_from_beam(
+                beam: BeamSearchSequence) -> TokensPrompt:
+            token_prompt_kwargs: TokensPrompt = {
+                "prompt_token_ids": beam.tokens
+            }
+            if beam.multi_modal_data is not None:
+                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
+
+            if beam.mm_processor_kwargs is not None:
+                token_prompt_kwargs[
+                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
+            return TokensPrompt(**token_prompt_kwargs)

         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
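create_tokens_prompt_from_beam sets the optional TokensPrompt keys only when the beam actually carries them, so text-only beams produce the same minimal prompt dict the old inline code built. A standalone sketch of that conditional-kwargs pattern, using a stand-in TypedDict rather than vLLM's own TokensPrompt:

# Standalone sketch of the conditional-kwargs pattern; this TokensPrompt is
# a stand-in TypedDict, not an import from vLLM.
from typing import Any, Optional, TypedDict


class TokensPrompt(TypedDict, total=False):
    prompt_token_ids: list[int]
    multi_modal_data: dict[str, Any]
    mm_processor_kwargs: dict[str, Any]


def tokens_prompt_for(tokens: list[int],
                      mm_data: Optional[dict[str, Any]] = None,
                      mm_kwargs: Optional[dict[str, Any]] = None) -> TokensPrompt:
    # Populate the optional keys only when present, so text-only beams
    # produce the same minimal dict as before the change.
    prompt: TokensPrompt = {"prompt_token_ids": tokens}
    if mm_data is not None:
        prompt["multi_modal_data"] = mm_data
    if mm_kwargs is not None:
        prompt["mm_processor_kwargs"] = mm_kwargs
    return prompt


print(tokens_prompt_for([1, 2, 3]))   # {'prompt_token_ids': [1, 2, 3]}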
@@ -556,11 +559,20 @@ class LLM:
         instances: list[BeamSearchInstance] = []

         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(
+                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
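Note that mm_kwargs preserves key absence: a text-only prompt yields an empty dict, so **mm_kwargs expands to nothing and instances are constructed exactly as before. A small illustrative sketch (collect_mm_kwargs is a hypothetical helper mirroring the inline key-copying above, not part of the patch):

# Illustrative sketch only; collect_mm_kwargs is a hypothetical helper,
# not a function from this patch.
from typing import Any


def collect_mm_kwargs(prompt: dict[str, Any]) -> dict[str, Any]:
    # Copy only keys that exist: absent keys stay absent rather than
    # becoming explicit None values, so **mm_kwargs expands to nothing
    # for text-only prompts.
    return {
        key: prompt[key]
        for key in ("multi_modal_data", "mm_processor_kwargs")
        if key in prompt
    }


mm_prompt = {"prompt": "Describe the image",
             "multi_modal_data": {"image": object()}}
text_prompt = {"prompt": "Hello"}
assert collect_mm_kwargs(text_prompt) == {}
assert "mm_processor_kwargs" not in collect_mm_kwargs(mm_prompt)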
@@ -575,8 +587,7 @@ class LLM:
                 break

             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                create_tokens_prompt_from_beam(beam) for beam in all_beams
             ]

             # only runs for one step
@@ -602,7 +613,10 @@ class LLM:
                             tokens=current_beam.tokens + [token_id],
                             logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
+                            logprob_obj.logprob,
+                            multi_modal_data=current_beam.multi_modal_data,
+                            mm_processor_kwargs=current_beam.
+                            mm_processor_kwargs)

                         if token_id == tokenizer.eos_token_id and \
                                 not ignore_eos:
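With this fix, multimodal data supplied to offline beam search is carried through every decoding step instead of being dropped with a warning. A hedged usage sketch: the model name is only an example, and the import paths and attribute names follow recent vLLM releases, so check them against your installed version:

# Hedged usage sketch; model choice is an example and API details may vary
# across vLLM versions.
from PIL import Image

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

# Any multimodal-capable model; llava-1.5 is just an example choice.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = Image.open("example.jpg")

outputs = llm.beam_search(
    [{
        "prompt": "USER: <image>\nDescribe the image. ASSISTANT:",
        "multi_modal_data": {"image": image},  # now threaded through each step
    }],
    BeamSearchParams(beam_width=4, max_tokens=64),
)

for seq in outputs[0].sequences:
    print(seq.cum_logprob, seq.text)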