diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index b8787c765..45af2b693 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -439,6 +439,8 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
 
 Code example: [examples/online_serving/openai_transcription_client.py](../../examples/online_serving/openai_transcription_client.py)
 
+NOTE: Beam search is currently supported in the transcriptions endpoint for encoder-decoder multimodal models (e.g., Whisper), but it is highly inefficient while work on encoder/decoder cache handling is ongoing; this is an active area of optimization and will be handled properly in the near future.
+
 #### API Enforced Limits
 
 Set the maximum audio file size (in MB) that VLLM will accept, via the
diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py
index cbee032a7..c2479efe4 100644
--- a/tests/entrypoints/openai/test_transcription_validation_whisper.py
+++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py
@@ -317,3 +317,72 @@ async def test_language_auto_detect(
     assert any(word.lower() in text_lower for word in expected_text), (
         f"Expected {expected_lang} text but got: {transcription.text}"
     )
+
+
+@pytest.mark.asyncio
+async def test_whisper_beam_search_single_beam(mary_had_lamb, whisper_client):
+    """Test that beam search with an encoder-decoder model (Whisper) on
+    transcriptions aligns with greedy decoding when a single beam is used.
+    """
+    beam_transcription = await whisper_client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="text",
+        temperature=0.0,
+        extra_body=dict(
+            use_beam_search=True,
+            n=1,
+        ),
+    )
+
+    greedy_transcription = await whisper_client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        response_format="text",
+        temperature=0.0,
+    )
+
+    greedy_res = json.loads(greedy_transcription)["text"]
+    beam_res = json.loads(beam_transcription)["text"]
+    assert greedy_res == beam_res
+
+
+@pytest.mark.asyncio
+async def test_whisper_beam_search_multibeam(mary_had_lamb, whisper_client):
+    """Test that beam search with n>1 returns a single transcription (the best beam)."""
+    transcription = await whisper_client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="text",
+        temperature=0.0,
+        extra_body=dict(
+            use_beam_search=True,
+            n=2,
+        ),
+    )
+
+    result = json.loads(transcription)
+
+    text = result["text"]
+
+    assert text is not None
+    assert len(text) > 0
+    assert "mary had a little lamb" in text.lower()
+
+
+@pytest.mark.asyncio
+async def test_stream_with_beams_raises(winning_call, whisper_client):
+    """Test that stream=True combined with beam search raises a BadRequestError for now."""
+    with pytest.raises(openai.BadRequestError):
+        await whisper_client.audio.transcriptions.create(
+            model=MODEL_NAME,
+            file=winning_call,
+            language="en",
+            stream=True,
+            extra_body=dict(
+                use_beam_search=True,
+                n=2,
+            ),
+        )
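To try the new behavior against a running server, the calls exercised by the tests above reduce to a short standalone client script. A minimal sketch (the base URL, API key, model name, and audio file are placeholders; the `extra_body` fields mirror the tests):

```python
import asyncio

import openai


async def main() -> None:
    # Placeholder endpoint and model; any Whisper-style model served by vLLM works.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    with open("mary_had_lamb.ogg", "rb") as audio:  # placeholder audio file
        transcription = await client.audio.transcriptions.create(
            model="openai/whisper-large-v3",
            file=audio,
            language="en",
            response_format="text",
            temperature=0.0,
            # vLLM-specific sampling options are passed through extra_body.
            extra_body=dict(use_beam_search=True, n=2),
        )
    print(transcription)


asyncio.run(main())
```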
diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py
index 27320cbd0..dc5ef5639 100644
--- a/vllm/entrypoints/openai/completion/serving.py
+++ b/vllm/entrypoints/openai/completion/serving.py
@@ -129,6 +129,11 @@ class OpenAIServingCompletion(OpenAIServing):
         - suffix (the language models we currently support do not support
           suffix)
         """
+        if request.stream and request.use_beam_search:
+            return self.create_error_response(
+                "Streaming is not currently supported with beam search"
+            )
+
         result = await self.render_completion_request(request)
         if isinstance(result, ErrorResponse):
             return result
@@ -211,13 +216,10 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name = self.models.model_name(lora_request)
         num_prompts = len(engine_prompts)
 
-        # We do not stream the results when using beam search.
-        stream = request.stream and not request.use_beam_search
-
         # Streaming response
         tokenizer = self.renderer.tokenizer
 
-        if stream:
+        if request.stream:
             return self.completion_stream_generator(
                 request,
                 engine_prompts,
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index 73557fac6..58e593ea5 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -237,13 +237,14 @@ class OpenAIServing:
         if prompt["type"] == "embeds":
             raise NotImplementedError("Embedding prompt not supported for beam search")
 
-        if prompt["type"] == "enc_dec":
-            raise NotImplementedError(
-                "Encoder-decoder prompt not supported for beam search"
-            )
-
-        prompt_text = prompt.get("prompt")
-        prompt_token_ids = prompt["prompt_token_ids"]
+        # Extract prompt tokens and text based on model type
+        decoder_prompt = (
+            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
+        )
+        prompt_text = decoder_prompt.get("prompt")
+        prompt_token_ids = decoder_prompt["prompt_token_ids"]
+
+        tokenized_length = len(prompt_token_ids)
 
         logprobs_num = 2 * beam_width
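The dispatch added above is small enough to illustrate in isolation. A self-contained sketch with simplified prompt dicts (the keys `decoder_prompt`, `prompt`, and `prompt_token_ids` follow the diff; the token IDs and example values are illustrative):

```python
from typing import Any


def extract_decoder_prompt(prompt: dict[str, Any]) -> tuple[str | None, list[int]]:
    """Return (prompt_text, prompt_token_ids) for beam search.

    Encoder-decoder prompts carry their generation-side tokens under
    ``decoder_prompt``; every other prompt type keeps them at the top level.
    """
    decoder_prompt = prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
    return decoder_prompt.get("prompt"), decoder_prompt["prompt_token_ids"]


# A Whisper-style encoder-decoder prompt (values illustrative).
enc_dec_prompt = {
    "type": "enc_dec",
    "encoder_prompt": {"prompt_token_ids": [50258]},
    "decoder_prompt": {
        "prompt": "<|startoftranscript|>",
        "prompt_token_ids": [50258, 50259],
    },
}
print(extract_decoder_prompt(enc_dec_prompt))
# -> ('<|startoftranscript|>', [50258, 50259])
```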
diff --git a/vllm/entrypoints/openai/speech_to_text/protocol.py b/vllm/entrypoints/openai/speech_to_text/protocol.py
index 978113e6a..ed32db2f0 100644
--- a/vllm/entrypoints/openai/speech_to_text/protocol.py
+++ b/vllm/entrypoints/openai/speech_to_text/protocol.py
@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
 from vllm.exceptions import VLLMValidationError
 from vllm.logger import init_logger
 from vllm.sampling_params import (
+    BeamSearchParams,
     RequestOutputKind,
     SamplingParams,
 )
@@ -123,6 +124,18 @@ class TranscriptionRequest(OpenAIBaseModel):
     """
 
     # --8<-- [start:transcription-sampling-params]
+    use_beam_search: bool = False
+    """Whether or not beam search should be used."""
+
+    n: int = 1
+    """The number of beams to be used in beam search."""
+
+    length_penalty: float = 1.0
+    """Length penalty to be used for beam search."""
+
+    include_stop_str_in_output: bool = False
+    """Whether to include the stop strings in output text."""
+
     temperature: float = Field(default=0.0)
     """The sampling temperature, between 0 and 1.
@@ -170,6 +183,29 @@ class TranscriptionRequest(OpenAIBaseModel):
             "min_p": 0.0,
         }
 
+    def to_beam_search_params(
+        self,
+        default_max_tokens: int,
+        default_sampling_params: dict | None = None,
+    ) -> BeamSearchParams:
+        if default_sampling_params is None:
+            default_sampling_params = {}
+
+        max_tokens = default_max_tokens
+        n = self.n if self.n is not None else 1
+
+        # NOTE: the temperature fallback is 0 here, unlike completions
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get("temperature", 0)
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            length_penalty=self.length_penalty,
+            include_stop_str_in_output=self.include_stop_str_in_output,
+        )
+
     def to_sampling_params(
         self, default_max_tokens: int, default_sampling_params: dict | None = None
     ) -> SamplingParams:
@@ -376,6 +412,18 @@ class TranslationRequest(OpenAIBaseModel):
 
     # TODO support additional sampling parameters
     # --8<-- [start:translation-sampling-params]
+    use_beam_search: bool = False
+    """Whether or not beam search should be used."""
+
+    n: int = 1
+    """The number of beams to be used in beam search."""
+
+    length_penalty: float = 1.0
+    """Length penalty to be used for beam search."""
+
+    include_stop_str_in_output: bool = False
+    """Whether to include the stop strings in output text."""
+
     seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
     """The seed to use for sampling."""
 
@@ -424,6 +472,29 @@ class TranslationRequest(OpenAIBaseModel):
             "temperature": 0,
         }
 
+    def to_beam_search_params(
+        self,
+        default_max_tokens: int,
+        default_sampling_params: dict | None = None,
+    ) -> BeamSearchParams:
+        if default_sampling_params is None:
+            default_sampling_params = {}
+
+        max_tokens = default_max_tokens
+        n = self.n if self.n is not None else 1
+
+        # NOTE: the temperature fallback is 0 here, unlike completions
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get("temperature", 0)
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            length_penalty=self.length_penalty,
+            include_stop_str_in_output=self.include_stop_str_in_output,
+        )
+
     def to_sampling_params(
         self, default_max_tokens: int, default_sampling_params: dict | None = None
     ) -> SamplingParams:
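One subtlety in the two `to_beam_search_params` helpers is the temperature fallback: when the request leaves temperature unset, these endpoints fall back to 0, whereas the completions-style endpoints use a different (non-zero) fallback. A minimal sketch of just that selection logic (the helper and its arguments are illustrative, not vLLM API):

```python
def resolve_temperature(
    request_temperature: float | None,
    default_sampling_params: dict,
    *,
    fallback: float,
) -> float:
    """Effective temperature: explicit request value first, then the
    server-level default sampling params, then a hard-coded fallback."""
    if request_temperature is not None:
        return request_temperature
    return default_sampling_params.get("temperature", fallback)


# Transcription/translation beam search falls back to 0 (greedy-like)...
assert resolve_temperature(None, {}, fallback=0.0) == 0.0
# ...while completions-style endpoints use a non-zero fallback (1.0 shown
# here for illustration).
assert resolve_temperature(None, {}, fallback=1.0) == 1.0
# A server-level default always beats the hard fallback.
assert resolve_temperature(None, {"temperature": 0.3}, fallback=0.0) == 0.3
```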
diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py
index 7f12892f4..3de088fa9 100644
--- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py
@@ -39,7 +39,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
 )
 from vllm.entrypoints.utils import get_max_tokens
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs import ProcessorInputs
+from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
 from vllm.logger import init_logger
 from vllm.logprobs import FlatLogprobs, Logprob
 from vllm.model_executor.models import (
@@ -50,6 +50,7 @@ from vllm.multimodal.audio import split_audio
 from vllm.outputs import RequestOutput
 from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
 from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.tokenizers import get_tokenizer
 from vllm.utils.import_utils import PlaceholderModule
 
@@ -264,8 +265,6 @@ class OpenAISpeechToText(OpenAIServing):
         via ``get_language_detection_prompt`` and
         ``parse_language_detection_output``.
         """
-        from vllm.sampling_params import SamplingParams
-
         prompt = self.model_cls.get_language_detection_prompt(
             audio_chunk,
             self.asr_config,
@@ -403,6 +402,26 @@ class OpenAISpeechToText(OpenAIServing):
 
         return prompt
 
+    @staticmethod
+    def _get_decoder_prompt_len(engine_prompts: list[ProcessorInputs]) -> int:
+        """Get the length of the decoder prompt. Currently we need to offset
+        max_tokens by the decoder prompt length when running beam search:
+        the multimodal encoder output is not yet cached and the encoder
+        re-runs on decode calls, so those redundant encoder passes must not
+        push the request past the model context length.
+
+        FIXME (Alex) - this will be removed in the near future once
+        encoder/decoder caching is implemented.
+        """
+        input_len = 0
+        assert len(engine_prompts) > 0
+        first_eng_prompt = engine_prompts[0]
+
+        if first_eng_prompt.get("type") == "enc_dec":
+            first_eng_prompt = cast(EncoderDecoderInputs, first_eng_prompt)
+            input_len = len(first_eng_prompt["decoder_prompt"]["prompt_token_ids"])
+        return input_len
+
     def _get_verbose_segments(
         self,
         tokens: tuple,
@@ -481,6 +500,11 @@ class OpenAISpeechToText(OpenAIServing):
     ) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
         """Base method for speech-to-text operations like transcription
         and translation."""
+        if request.stream and request.use_beam_search:
+            return self.create_error_response(
+                "Streaming is not currently supported with beam search"
+            )
+
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
@@ -526,6 +550,13 @@ class OpenAISpeechToText(OpenAIServing):
         # Schedule the request and get the result generator.
         max_model_len = self.model_config.max_model_len
         list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
+
+        input_len = (
+            OpenAISpeechToText._get_decoder_prompt_len(engine_prompts)
+            if request.use_beam_search
+            else 0
+        )
+
         # Unlike most decoder-only models, whisper generation length is not
         # constrained by the size of the input audio, which is mapped to a
         # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
         max_tokens = get_max_tokens(
             max_model_len,
             request.max_completion_tokens,
-            0,
+            input_len,
             self.default_sampling_params,
         )
 
-        sampling_params = request.to_sampling_params(
-            max_tokens,
-            self.default_sampling_params,
-        )
+        if request.use_beam_search:
+            sampling_params = request.to_beam_search_params(
+                max_tokens, self.default_sampling_params
+            )
+        else:
+            sampling_params = request.to_sampling_params(
+                max_tokens,
+                self.default_sampling_params,
+            )
 
         if request.response_format == "verbose_json":
             sampling_params.logprobs = 1
@@ -561,13 +598,22 @@ class OpenAISpeechToText(OpenAIServing):
                 else await self._get_trace_headers(raw_request.headers)
             )
 
-            generator = self.engine_client.generate(
-                engine_prompt,
-                sampling_params,
-                request_id_item,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-            )
+            if isinstance(sampling_params, BeamSearchParams):
+                generator = self.beam_search(
+                    prompt=engine_prompt,
+                    params=sampling_params,
+                    request_id=request_id_item,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                )
+            else:
+                generator = self.engine_client.generate(
+                    engine_prompt,
+                    sampling_params,
+                    request_id_item,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                )
 
             list_result_generator.append(generator)
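Until encoder/decoder caching lands, the generation budget for beam search is offset by the decoder prompt length computed above. The intent reduces to roughly the following arithmetic (a sketch of the idea, not vLLM's actual `get_max_tokens`, whose real signature also takes the default sampling params):

```python
def max_generation_tokens(
    max_model_len: int,
    requested_max_tokens: int | None,
    decoder_prompt_len: int,
) -> int:
    """Cap generation so the decoder prompt plus generated tokens (and
    thus the redundant encoder re-runs during beam search decode steps)
    cannot exceed the model context."""
    budget = max_model_len - decoder_prompt_len
    if requested_max_tokens is None:
        return budget
    return min(requested_max_tokens, budget)


# Whisper's decoder context is 448 tokens; with a 4-token decoder prompt,
# an unconstrained request may generate at most 444 tokens.
assert max_generation_tokens(448, None, 4) == 444
assert max_generation_tokens(448, 100, 4) == 100
```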