[Frontend] Add Support for MM Encoder/Decoder Beam Search (Offline) (#36153)
Signed-off-by: Alex Brooks <albrooks@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -843,7 +843,10 @@ class VllmRunner:
|
|||||||
|
|
||||||
def get_inputs(
|
def get_inputs(
|
||||||
self,
|
self,
|
||||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
prompts: list[str]
|
||||||
|
| list[torch.Tensor]
|
||||||
|
| list[list[int]]
|
||||||
|
| list[dict[str, Any]],
|
||||||
images: PromptImageInput | None = None,
|
images: PromptImageInput | None = None,
|
||||||
videos: PromptVideoInput | None = None,
|
videos: PromptVideoInput | None = None,
|
||||||
audios: PromptAudioInput | None = None,
|
audios: PromptAudioInput | None = None,
|
||||||
@@ -857,26 +860,32 @@ class VllmRunner:
|
|||||||
|
|
||||||
inputs = list[dict[str, Any]]()
|
inputs = list[dict[str, Any]]()
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
prompt_dict = dict[str, Any]()
|
# If we're passing an encoder/decoder prompt, we assume it
|
||||||
if isinstance(prompt, str):
|
# already contains the multimodal data in the prompt
|
||||||
prompt_dict["prompt"] = prompt
|
if isinstance(prompt, dict):
|
||||||
elif isinstance(prompt, list):
|
assert images is None and audios is None and videos is None
|
||||||
prompt_dict["prompt_token_ids"] = prompt
|
inputs.append(prompt.copy())
|
||||||
else:
|
else:
|
||||||
prompt_dict["prompt_embeds"] = prompt
|
prompt_dict = dict[str, Any]()
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt_dict["prompt"] = prompt
|
||||||
|
elif isinstance(prompt, list):
|
||||||
|
prompt_dict["prompt_token_ids"] = prompt
|
||||||
|
else:
|
||||||
|
prompt_dict["prompt_embeds"] = prompt
|
||||||
|
|
||||||
multi_modal_data = dict[str, Any]()
|
multi_modal_data = dict[str, Any]()
|
||||||
if images is not None and (image := images[i]) is not None:
|
if images is not None and (image := images[i]) is not None:
|
||||||
multi_modal_data["image"] = image
|
multi_modal_data["image"] = image
|
||||||
if videos is not None and (video := videos[i]) is not None:
|
if videos is not None and (video := videos[i]) is not None:
|
||||||
multi_modal_data["video"] = video
|
multi_modal_data["video"] = video
|
||||||
if audios is not None and (audio := audios[i]) is not None:
|
if audios is not None and (audio := audios[i]) is not None:
|
||||||
multi_modal_data["audio"] = audio
|
multi_modal_data["audio"] = audio
|
||||||
|
|
||||||
if multi_modal_data:
|
if multi_modal_data:
|
||||||
prompt_dict["multi_modal_data"] = multi_modal_data
|
prompt_dict["multi_modal_data"] = multi_modal_data
|
||||||
|
|
||||||
inputs.append(prompt_dict)
|
inputs.append(prompt_dict)
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|||||||
@@ -90,9 +90,9 @@ def run_test(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]:
|
def resampled_assets() -> list[tuple[Any, int]]:
|
||||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||||
inputs = []
|
sampled_assets = []
|
||||||
for asset in audio_assets:
|
for asset in audio_assets:
|
||||||
audio, orig_sr = asset.audio_and_sample_rate
|
audio, orig_sr = asset.audio_and_sample_rate
|
||||||
# Resample to Whisper's expected sample rate (16kHz)
|
# Resample to Whisper's expected sample rate (16kHz)
|
||||||
@@ -100,8 +100,21 @@ def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]:
|
|||||||
audio = librosa.resample(
|
audio = librosa.resample(
|
||||||
audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE
|
audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE
|
||||||
)
|
)
|
||||||
|
sampled_assets.append(
|
||||||
|
(audio, WHISPER_SAMPLE_RATE),
|
||||||
|
)
|
||||||
|
return sampled_assets
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def input_audios(
|
||||||
|
resampled_assets,
|
||||||
|
) -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]:
|
||||||
|
inputs = []
|
||||||
|
# audio assets are resampled to WHISPER_SAMPLE_RATE
|
||||||
|
for audio_info in resampled_assets:
|
||||||
# vLLM prompts, HF prompts, audio inputs
|
# vLLM prompts, HF prompts, audio inputs
|
||||||
inputs.append(([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)]))
|
inputs.append(([VLLM_PROMPT], [HF_PROMPT], [audio_info]))
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
@@ -111,6 +124,98 @@ def check_model_available(model: str) -> None:
|
|||||||
model_info.check_transformers_version(on_fail="skip")
|
model_info.check_transformers_version(on_fail="skip")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
|
@pytest.mark.parametrize("beam_width", [1, 2])
|
||||||
|
def test_beam_search_encoder_decoder(
|
||||||
|
monkeypatch,
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
beam_width: int,
|
||||||
|
resampled_assets,
|
||||||
|
) -> None:
|
||||||
|
"""Test beam search with encoder-decoder models (Whisper)."""
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")
|
||||||
|
|
||||||
|
model = "openai/whisper-large-v3-turbo"
|
||||||
|
check_model_available(model)
|
||||||
|
|
||||||
|
hf_prompts = [
|
||||||
|
"<|startoftranscript|>",
|
||||||
|
"<|startoftranscript|>",
|
||||||
|
]
|
||||||
|
|
||||||
|
with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
|
||||||
|
hf_outputs = hf_model.generate_beam_search(
|
||||||
|
hf_prompts,
|
||||||
|
beam_width=beam_width,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
audios=resampled_assets,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test both explicit encoder/decoder prompts
|
||||||
|
vllm_prompts = [
|
||||||
|
# Implicit encoder/decoder prompt
|
||||||
|
{
|
||||||
|
"prompt": "<|startoftranscript|>",
|
||||||
|
"multi_modal_data": {"audio": resampled_assets[0]},
|
||||||
|
},
|
||||||
|
# Explicit encoder/decover prompt
|
||||||
|
{
|
||||||
|
"encoder_prompt": {
|
||||||
|
"prompt": "",
|
||||||
|
"multi_modal_data": {"audio": resampled_assets[1]},
|
||||||
|
},
|
||||||
|
"decoder_prompt": "<|startoftranscript|>",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with vllm_runner(
|
||||||
|
model,
|
||||||
|
dtype="half",
|
||||||
|
max_model_len=448,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
max_num_seqs=4,
|
||||||
|
limit_mm_per_prompt={"audio": 2},
|
||||||
|
enforce_eager=True,
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_outputs = vllm_model.generate_beam_search(
|
||||||
|
vllm_prompts,
|
||||||
|
beam_width=beam_width,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(vllm_prompts)):
|
||||||
|
hf_output_ids, hf_output_texts = hf_outputs[i]
|
||||||
|
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
|
||||||
|
|
||||||
|
for j, (hf_text, vllm_text) in enumerate(
|
||||||
|
zip(hf_output_texts, vllm_output_texts)
|
||||||
|
):
|
||||||
|
print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
|
||||||
|
print(hf_text)
|
||||||
|
print(f">>>{j}-th vllm output:")
|
||||||
|
print(vllm_text)
|
||||||
|
|
||||||
|
# Check that we got the same number of beams
|
||||||
|
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||||
|
|
||||||
|
# For encoder-decoder models, we primarily want to verify that:
|
||||||
|
# 1. Beam search completes without errors
|
||||||
|
# 2. We get the expected number of beams
|
||||||
|
# 3. Outputs are reasonable (non-empty, diverse beams)
|
||||||
|
for j in range(len(vllm_output_ids)):
|
||||||
|
# Check that outputs are not empty
|
||||||
|
assert len(vllm_output_ids[j]) > 0, f"Prompt {i}, beam {j}: empty output"
|
||||||
|
# Check that decoded text is not empty
|
||||||
|
assert len(vllm_output_texts[j].strip()) > 0, (
|
||||||
|
f"Prompt {i}, beam {j}: empty text output"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_language_detection_output():
|
def test_parse_language_detection_output():
|
||||||
"""Unit test for WhisperForConditionalGeneration.parse_language_detection_output.
|
"""Unit test for WhisperForConditionalGeneration.parse_language_detection_output.
|
||||||
|
|
||||||
|
|||||||
@@ -219,3 +219,7 @@ def test_beam_search_passes_multimodal_data(
|
|||||||
filtered_hf_output_ids = filtered_hf_output_ids[:-1]
|
filtered_hf_output_ids = filtered_hf_output_ids[:-1]
|
||||||
|
|
||||||
assert filtered_hf_output_ids == filtered_vllm_output_ids
|
assert filtered_hf_output_ids == filtered_vllm_output_ids
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: encoder/decoder tests are currently located under
|
||||||
|
# tests/models/multimodal/generation/test_whisper.py
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.inputs import TokenInputs, token_inputs
|
from vllm.inputs import EncoderDecoderInputs, TokenInputs, token_inputs
|
||||||
|
from vllm.inputs.data import DecoderInputs
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
|
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
|
||||||
@@ -17,9 +18,9 @@ class BeamSearchSequence:
|
|||||||
about to be returned to the user.
|
about to be returned to the user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
orig_prompt: TokenInputs | MultiModalInputs
|
orig_prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs
|
||||||
|
|
||||||
# The tokens include the prompt.
|
# NOTE: Tokens represents decoder tokens in the encoder / decoder case
|
||||||
tokens: list[int]
|
tokens: list[int]
|
||||||
logprobs: list[dict[int, Logprob]]
|
logprobs: list[dict[int, Logprob]]
|
||||||
lora_request: LoRARequest | None = None
|
lora_request: LoRARequest | None = None
|
||||||
@@ -31,6 +32,10 @@ class BeamSearchSequence:
|
|||||||
def get_prompt(self):
|
def get_prompt(self):
|
||||||
prompt = self.orig_prompt
|
prompt = self.orig_prompt
|
||||||
|
|
||||||
|
if prompt["type"] == "enc_dec":
|
||||||
|
return self._build_encoder_decoder_inputs(prompt)
|
||||||
|
|
||||||
|
# Handle decoder-only inputs
|
||||||
prompt_text = prompt.get("prompt")
|
prompt_text = prompt.get("prompt")
|
||||||
cache_salt = prompt.get("cache_salt")
|
cache_salt = prompt.get("cache_salt")
|
||||||
|
|
||||||
@@ -50,6 +55,44 @@ class BeamSearchSequence:
|
|||||||
cache_salt=cache_salt,
|
cache_salt=cache_salt,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _build_encoder_decoder_inputs(
|
||||||
|
self, prompt: EncoderDecoderInputs
|
||||||
|
) -> EncoderDecoderInputs:
|
||||||
|
"""Rebuild the encoder-decoder inputs with the current beam search
|
||||||
|
sequence's tokens.
|
||||||
|
|
||||||
|
FIXME (alex) - the encoder multimodal cache is not properly wired up
|
||||||
|
yet, which means that currently we are running the encoder on every
|
||||||
|
new beam because num_computed_tokens is 0 on each new request. This
|
||||||
|
will be fixed once the cache is correctly implemented.
|
||||||
|
"""
|
||||||
|
dec_prompt = prompt["decoder_prompt"]
|
||||||
|
|
||||||
|
# Rebuild decoder prompt with updated tokens,
|
||||||
|
# but keep everything else the same.
|
||||||
|
new_dec_prompt: DecoderInputs
|
||||||
|
if dec_prompt["type"] == "multimodal":
|
||||||
|
new_dec_prompt = mm_inputs(
|
||||||
|
self.tokens,
|
||||||
|
mm_kwargs=dec_prompt["mm_kwargs"],
|
||||||
|
mm_hashes=dec_prompt["mm_hashes"],
|
||||||
|
mm_placeholders=dec_prompt["mm_placeholders"],
|
||||||
|
prompt=dec_prompt.get("prompt"),
|
||||||
|
cache_salt=dec_prompt.get("cache_salt"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_dec_prompt = token_inputs(
|
||||||
|
self.tokens,
|
||||||
|
prompt=dec_prompt.get("prompt"),
|
||||||
|
cache_salt=dec_prompt.get("cache_salt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return EncoderDecoderInputs(
|
||||||
|
type="enc_dec",
|
||||||
|
encoder_prompt=prompt["encoder_prompt"],
|
||||||
|
decoder_prompt=new_dec_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BeamSearchOutput:
|
class BeamSearchOutput:
|
||||||
@@ -64,15 +107,20 @@ class BeamSearchOutput:
|
|||||||
class BeamSearchInstance:
|
class BeamSearchInstance:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompt: TokenInputs | MultiModalInputs,
|
prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs,
|
||||||
lora_request: LoRARequest | None = None,
|
lora_request: LoRARequest | None = None,
|
||||||
logprobs: list[dict[int, Logprob]] | None = None,
|
logprobs: list[dict[int, Logprob]] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
decoder_prompt = (
|
||||||
|
prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
|
||||||
|
)
|
||||||
|
initial_tokens = decoder_prompt["prompt_token_ids"]
|
||||||
|
|
||||||
self.beams: list[BeamSearchSequence] = [
|
self.beams: list[BeamSearchSequence] = [
|
||||||
BeamSearchSequence(
|
BeamSearchSequence(
|
||||||
orig_prompt=prompt,
|
orig_prompt=prompt,
|
||||||
tokens=prompt["prompt_token_ids"],
|
tokens=initial_tokens,
|
||||||
logprobs=[] if logprobs is None else list(logprobs),
|
logprobs=[] if logprobs is None else list(logprobs),
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -734,10 +734,6 @@ class LLM:
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Embedding prompt not supported for beam search"
|
"Embedding prompt not supported for beam search"
|
||||||
)
|
)
|
||||||
if prompt["type"] == "enc_dec":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Encoder-decoder prompt not supported for beam search"
|
|
||||||
)
|
|
||||||
|
|
||||||
instances.append(
|
instances.append(
|
||||||
BeamSearchInstance(
|
BeamSearchInstance(
|
||||||
|
|||||||
Reference in New Issue
Block a user