[Core][Bugfix] Fix Offline MM Beam Search (#16390)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Alex Brooks
Date: 2025-04-14 20:33:02 -06:00
Commit: 6b40996ae8 (parent d2020acac7)

4 changed files with 140 additions and 30 deletions

@@ -38,9 +38,18 @@ class BeamSearchOutput:


 class BeamSearchInstance:

-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
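
For orientation, this first hunk (presumably vllm/beam_search.py) widens the BeamSearchInstance constructor so callers can seed the root beam with multimodal state. A minimal usage sketch, assuming the BeamSearchSequence dataclass carries the multi_modal_data and mm_processor_kwargs fields used later in this commit:

    # Hypothetical usage, not part of the diff.
    from PIL import Image

    from vllm.beam_search import BeamSearchInstance

    image = Image.new("RGB", (224, 224))  # stand-in for a real input image

    instance = BeamSearchInstance(
        [1, 2, 3],                 # prompt_tokens
        logprobs=None,             # root beam starts with an empty logprob list
        # Forwarded through **kwargs onto the initial BeamSearchSequence:
        multi_modal_data={"image": image},
        mm_processor_kwargs={"num_crops": 4},
    )

    assert instance.beams[0].multi_modal_data is not None
    assert instance.completed == []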

@@ -536,15 +536,18 @@ class LLM:
                                          tokenizer.eos_token_id,
                                          length_penalty)

-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
+        def create_tokens_prompt_from_beam(
+                beam: BeamSearchSequence) -> TokensPrompt:
+            token_prompt_kwargs: TokensPrompt = {
+                "prompt_token_ids": beam.tokens
+            }
+            if beam.multi_modal_data is not None:
+                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
+
+            if beam.mm_processor_kwargs is not None:
+                token_prompt_kwargs[
+                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
+            return TokensPrompt(**token_prompt_kwargs)

         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
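
The helper above matters because the decode loop rebuilds a fresh TokensPrompt for every live beam at every step; the removed warning documented that the offline path previously dropped multimodal inputs entirely. A stand-alone approximation of the helper's behavior (plain dicts model the TokensPrompt TypedDict adequately here):

    from typing import Any, Optional

    def tokens_prompt_from_beam(
            tokens: list[int],
            multi_modal_data: Optional[dict[str, Any]] = None,
            mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        # Keys are attached only when the beam actually carries multimodal
        # state, so text-only beams remain plain token prompts.
        prompt: dict[str, Any] = {"prompt_token_ids": tokens}
        if multi_modal_data is not None:
            prompt["multi_modal_data"] = multi_modal_data
        if mm_processor_kwargs is not None:
            prompt["mm_processor_kwargs"] = mm_processor_kwargs
        return prompt

    assert tokens_prompt_from_beam([1, 2]) == {"prompt_token_ids": [1, 2]}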
@@ -556,11 +559,20 @@
         instances: list[BeamSearchInstance] = []

         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(
+                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
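
With the constructor change and the per-prompt extraction above, LLM.beam_search now accepts multimodal prompts end to end. A hedged sketch of the call (model name, chat template, and num_crops are illustrative placeholders; the prompt keys are the ones read in this hunk):

    from PIL import Image

    from vllm import LLM
    from vllm.sampling_params import BeamSearchParams

    llm = LLM(model="microsoft/Phi-3.5-vision-instruct")  # any MM-capable model

    outputs = llm.beam_search(
        [{
            "prompt": "<|user|>\n<|image_1|>\nWhat is shown here?<|end|>\n<|assistant|>\n",
            "multi_modal_data": {"image": Image.new("RGB", (640, 480))},
            "mm_processor_kwargs": {"num_crops": 16},  # model-specific knob
        }],
        BeamSearchParams(beam_width=4, max_tokens=64),
    )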
@@ -575,8 +587,7 @@
                 break

             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                create_tokens_prompt_from_beam(beam) for beam in all_beams
             ]

             # only runs for one step
@@ -602,7 +613,10 @@
                                 tokens=current_beam.tokens + [token_id],
                                 logprobs=current_beam.logprobs + [logprobs],
                                 cum_logprob=current_beam.cum_logprob +
-                                logprob_obj.logprob)
+                                logprob_obj.logprob,
+                                multi_modal_data=current_beam.multi_modal_data,
+                                mm_processor_kwargs=current_beam.
+                                mm_processor_kwargs)

                             if token_id == tokenizer.eos_token_id and \
                                     not ignore_eos:
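
Because each child beam now inherits multi_modal_data and mm_processor_kwargs from its parent, the inputs survive every decode step through to the returned BeamSearchOutput. Consuming the result of the earlier sketch, assuming (as in the async client) that .text is populated on the returned sequences:

    for output in outputs:
        # Sequences are sorted by the length-penalized beam-search score
        # computed with the sort key shown earlier in this diff.
        for rank, seq in enumerate(output.sequences):
            print(f"beam {rank}: cum_logprob={seq.cum_logprob:.3f} text={seq.text!r}")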