diff --git a/tests/conftest.py b/tests/conftest.py index b68696878..1e9d46d3c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -843,7 +843,10 @@ class VllmRunner: def get_inputs( self, - prompts: list[str] | list[torch.Tensor] | list[list[int]], + prompts: list[str] + | list[torch.Tensor] + | list[list[int]] + | list[dict[str, Any]], images: PromptImageInput | None = None, videos: PromptVideoInput | None = None, audios: PromptAudioInput | None = None, @@ -857,26 +860,32 @@ class VllmRunner: inputs = list[dict[str, Any]]() for i, prompt in enumerate(prompts): - prompt_dict = dict[str, Any]() - if isinstance(prompt, str): - prompt_dict["prompt"] = prompt - elif isinstance(prompt, list): - prompt_dict["prompt_token_ids"] = prompt + # If we're passing an encoder/decoder prompt, we assume it + # already contains the multimodal data in the prompt + if isinstance(prompt, dict): + assert images is None and audios is None and videos is None + inputs.append(prompt.copy()) else: - prompt_dict["prompt_embeds"] = prompt + prompt_dict = dict[str, Any]() + if isinstance(prompt, str): + prompt_dict["prompt"] = prompt + elif isinstance(prompt, list): + prompt_dict["prompt_token_ids"] = prompt + else: + prompt_dict["prompt_embeds"] = prompt - multi_modal_data = dict[str, Any]() - if images is not None and (image := images[i]) is not None: - multi_modal_data["image"] = image - if videos is not None and (video := videos[i]) is not None: - multi_modal_data["video"] = video - if audios is not None and (audio := audios[i]) is not None: - multi_modal_data["audio"] = audio + multi_modal_data = dict[str, Any]() + if images is not None and (image := images[i]) is not None: + multi_modal_data["image"] = image + if videos is not None and (video := videos[i]) is not None: + multi_modal_data["video"] = video + if audios is not None and (audio := audios[i]) is not None: + multi_modal_data["audio"] = audio - if multi_modal_data: - prompt_dict["multi_modal_data"] = multi_modal_data + if 
multi_modal_data: + prompt_dict["multi_modal_data"] = multi_modal_data - inputs.append(prompt_dict) + inputs.append(prompt_dict) return inputs diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4d58ad0a8..babf7e7a4 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -90,9 +90,9 @@ def run_test( @pytest.fixture -def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]: +def resampled_assets() -> list[tuple[Any, int]]: audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] - inputs = [] + sampled_assets = [] for asset in audio_assets: audio, orig_sr = asset.audio_and_sample_rate # Resample to Whisper's expected sample rate (16kHz) @@ -100,8 +100,21 @@ def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]: audio = librosa.resample( audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE ) + sampled_assets.append( + (audio, WHISPER_SAMPLE_RATE), + ) + return sampled_assets + + +@pytest.fixture +def input_audios( + resampled_assets, +) -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]: + inputs = [] + # audio assets are resampled to WHISPER_SAMPLE_RATE + for audio_info in resampled_assets: # vLLM prompts, HF prompts, audio inputs - inputs.append(([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)])) + inputs.append(([VLLM_PROMPT], [HF_PROMPT], [audio_info])) return inputs @@ -111,6 +124,98 @@ def check_model_available(model: str) -> None: model_info.check_transformers_version(on_fail="skip") +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("beam_width", [1, 2]) +def test_beam_search_encoder_decoder( + monkeypatch, + hf_runner, + vllm_runner, + dtype: str, + max_tokens: int, + beam_width: int, + resampled_assets, +) -> None: + """Test beam search with encoder-decoder models 
(Whisper).""" + if current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0") + + model = "openai/whisper-large-v3-turbo" + check_model_available(model) + + hf_prompts = [ + "<|startoftranscript|>", + "<|startoftranscript|>", + ] + + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: + hf_outputs = hf_model.generate_beam_search( + hf_prompts, + beam_width=beam_width, + max_tokens=max_tokens, + audios=resampled_assets, + ) + + # Test both implicit and explicit encoder/decoder prompts + vllm_prompts = [ + # Implicit encoder/decoder prompt + { + "prompt": "<|startoftranscript|>", + "multi_modal_data": {"audio": resampled_assets[0]}, + }, + # Explicit encoder/decoder prompt + { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": {"audio": resampled_assets[1]}, + }, + "decoder_prompt": "<|startoftranscript|>", + }, + ] + + with vllm_runner( + model, + dtype="half", + max_model_len=448, + tensor_parallel_size=1, + max_num_seqs=4, + limit_mm_per_prompt={"audio": 2}, + enforce_eager=True, + ) as vllm_model: + vllm_outputs = vllm_model.generate_beam_search( + vllm_prompts, + beam_width=beam_width, + max_tokens=max_tokens, + ) + + for i in range(len(vllm_prompts)): + hf_output_ids, hf_output_texts = hf_outputs[i] + vllm_output_ids, vllm_output_texts = vllm_outputs[i] + + for j, (hf_text, vllm_text) in enumerate( + zip(hf_output_texts, vllm_output_texts) + ): + print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:") + print(hf_text) + print(f">>>{j}-th vllm output:") + print(vllm_text) + + # Check that we got the same number of beams + assert len(hf_output_ids) == len(vllm_output_ids) + + # For encoder-decoder models, we primarily want to verify that: + # 1. Beam search completes without errors + # 2. We get the expected number of beams + # 3. 
Outputs are reasonable (non-empty, diverse beams) + for j in range(len(vllm_output_ids)): + # Check that outputs are not empty + assert len(vllm_output_ids[j]) > 0, f"Prompt {i}, beam {j}: empty output" + # Check that decoded text is not empty + assert len(vllm_output_texts[j].strip()) > 0, ( + f"Prompt {i}, beam {j}: empty text output" + ) + + def test_parse_language_detection_output(): """Unit test for WhisperForConditionalGeneration.parse_language_detection_output. diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 98675856a..e17e6d8ae 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -219,3 +219,7 @@ def test_beam_search_passes_multimodal_data( filtered_hf_output_ids = filtered_hf_output_ids[:-1] assert filtered_hf_output_ids == filtered_vllm_output_ids + + +# NOTE: encoder/decoder tests are currently located under +# tests/models/multimodal/generation/test_whisper.py diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 239327dc9..230f5a123 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -3,7 +3,8 @@ from dataclasses import dataclass -from vllm.inputs import TokenInputs, token_inputs +from vllm.inputs import EncoderDecoderInputs, TokenInputs, token_inputs +from vllm.inputs.data import DecoderInputs from vllm.logprobs import Logprob from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalInputs, mm_inputs @@ -17,9 +18,9 @@ class BeamSearchSequence: about to be returned to the user. """ - orig_prompt: TokenInputs | MultiModalInputs + orig_prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs - # The tokens include the prompt. 
+ # NOTE: Tokens represents decoder tokens in the encoder / decoder case tokens: list[int] logprobs: list[dict[int, Logprob]] lora_request: LoRARequest | None = None @@ -31,6 +32,10 @@ class BeamSearchSequence: def get_prompt(self): prompt = self.orig_prompt + if prompt["type"] == "enc_dec": + return self._build_encoder_decoder_inputs(prompt) + + # Handle decoder-only inputs prompt_text = prompt.get("prompt") cache_salt = prompt.get("cache_salt") @@ -50,6 +55,44 @@ class BeamSearchSequence: cache_salt=cache_salt, ) + def _build_encoder_decoder_inputs( + self, prompt: EncoderDecoderInputs + ) -> EncoderDecoderInputs: + """Rebuild the encoder-decoder inputs with the current beam search + sequence's tokens. + + FIXME (alex) - the encoder multimodal cache is not properly wired up + yet, which means that currently we are running the encoder on every + new beam because num_computed_tokens is 0 on each new request. This + will be fixed once the cache is correctly implemented. + """ + dec_prompt = prompt["decoder_prompt"] + + # Rebuild decoder prompt with updated tokens, + # but keep everything else the same. 
+ new_dec_prompt: DecoderInputs + if dec_prompt["type"] == "multimodal": + new_dec_prompt = mm_inputs( + self.tokens, + mm_kwargs=dec_prompt["mm_kwargs"], + mm_hashes=dec_prompt["mm_hashes"], + mm_placeholders=dec_prompt["mm_placeholders"], + prompt=dec_prompt.get("prompt"), + cache_salt=dec_prompt.get("cache_salt"), + ) + else: + new_dec_prompt = token_inputs( + self.tokens, + prompt=dec_prompt.get("prompt"), + cache_salt=dec_prompt.get("cache_salt"), + ) + + return EncoderDecoderInputs( + type="enc_dec", + encoder_prompt=prompt["encoder_prompt"], + decoder_prompt=new_dec_prompt, + ) + @dataclass class BeamSearchOutput: @@ -64,15 +107,20 @@ class BeamSearchOutput: class BeamSearchInstance: def __init__( self, - prompt: TokenInputs | MultiModalInputs, + prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs, lora_request: LoRARequest | None = None, logprobs: list[dict[int, Logprob]] | None = None, **kwargs, ): + decoder_prompt = ( + prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"] + ) + initial_tokens = decoder_prompt["prompt_token_ids"] + self.beams: list[BeamSearchSequence] = [ BeamSearchSequence( orig_prompt=prompt, - tokens=prompt["prompt_token_ids"], + tokens=initial_tokens, logprobs=[] if logprobs is None else list(logprobs), lora_request=lora_request, **kwargs, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d5a51a6b9..eb1d4dbeb 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -734,10 +734,6 @@ class LLM: raise NotImplementedError( "Embedding prompt not supported for beam search" ) - if prompt["type"] == "enc_dec": - raise NotImplementedError( - "Encoder-decoder prompt not supported for beam search" - ) instances.append( BeamSearchInstance(