# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compare the outputs of HF and vLLM when using beam search.

Run `pytest tests/samplers/test_beam_search.py`.
"""

import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.assets.audio import AudioAsset
from vllm.platforms import current_platform

# Extra engine kwargs needed for numerically deterministic beam search.
# On ROCm, floating-point reductions in attention and GEMM kernels are
# non-associative and sensitive to batch geometry, so we set:
#   async_scheduling=False - deterministic batch composition
#   enforce_eager=True - no CUDA-graph padding changing effective size
#   enable_prefix_caching=False - avoid prefix-sharing side effects
#   max_num_seqs=1 - fixed batch size across runs
# On other platforms, only async_scheduling=False and max_num_seqs=1 are
# needed.
EXTRA_ENGINE_KWARGS: dict = (
    dict(
        async_scheduling=False,
        enforce_eager=True,
        enable_prefix_caching=False,
        max_num_seqs=1,
    )
    if current_platform.is_rocm()
    else dict(async_scheduling=False, max_num_seqs=1)
)
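
# These kwargs are unpacked into every vllm_runner(...) call below so that
# all vLLM engines in this file share the same deterministic configuration.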

# FIXME(zhuohan): The test cannot pass if we:
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [64]
BEAM_WIDTHS = [4]
MM_BEAM_WIDTHS = [2]
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
    monkeypatch,
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    if current_platform.is_rocm():
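        # VLLM_ROCM_USE_SKINNY_GEMM selects ROCm-specific skinny-GEMM kernels
        # whose numerics can vary with batch shape; disable them so outputs
        # stay stable and comparable with HF.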
        monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")

    example_prompts = example_prompts[:1]
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_beam_search(
            example_prompts, beam_width, max_tokens
        )

    with vllm_runner(model, dtype=dtype, **EXTRA_ENGINE_KWARGS) as vllm_model:
        vllm_outputs = vllm_model.generate_beam_search(
            example_prompts, beam_width, max_tokens
        )

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_texts = hf_outputs[i]
        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
        for j, (hf_text, vllm_text) in enumerate(
            zip(hf_output_texts, vllm_output_texts)
        ):
            print(f">>>{j}-th hf output:")
            print(hf_text)
            print(f">>>{j}-th vllm output:")
            print(vllm_text)
        assert len(hf_output_ids) == len(vllm_output_ids)
        for j in range(len(hf_output_ids)):
            assert hf_output_ids[j] == vllm_output_ids[j], (
                f"Test{i} output{j}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}"
            )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_with_concurrency_limit(
    monkeypatch,
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    if current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")

    # example_prompts[1], [3], and [7] fail for unknown reasons even without
    # a concurrency limit; skip them for now.
    example_prompts = example_prompts[:8]
    concurrency_limit = 2
    assert len(example_prompts) > concurrency_limit
    with vllm_runner(model, dtype=dtype, **EXTRA_ENGINE_KWARGS) as vllm_model:
        outputs_with_limit = vllm_model.generate_beam_search(
            example_prompts,
            beam_width,
            max_tokens,
            concurrency_limit=concurrency_limit,
        )
        outputs_without_limit = []
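        # Baseline: generate the same prompts without passing
        # concurrency_limit, feeding at most `concurrency_limit` prompts
        # per call.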

        for i in range(0, len(example_prompts), concurrency_limit):
            outputs_without_limit.extend(
                vllm_model.generate_beam_search(
                    example_prompts[i : i + concurrency_limit],
                    beam_width,
                    max_tokens,
                )
            )
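
    # Collect all mismatches before asserting so every diff gets printed.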
    correct = True
    for i in range(len(example_prompts)):
        output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i]
        output_ids_without_limit, output_texts_without_limit = outputs_without_limit[i]
        for j, (text_with_limit, text_without_limit) in enumerate(
            zip(output_texts_with_limit, output_texts_without_limit)
        ):
            print(f">>>{j}-th with limit output:")
            print(text_with_limit)
            print(f">>>{j}-th without limit output:")
            print(text_without_limit)
        assert len(output_ids_with_limit) == len(output_ids_without_limit)
        for j in range(len(output_ids_with_limit)):
            if output_ids_with_limit[j] != output_ids_without_limit[j]:
                print(
                    f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n"
                    f"-limit: {output_ids_without_limit}"
                )
                correct = False
    assert correct


@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
def test_beam_search_passes_multimodal_data(
    monkeypatch,
    hf_runner,
    vllm_runner,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    """Ensure that beam search passes multimodal data through correctly."""
    if current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")

    # NOTE: this test primarily checks that multimodal data is passed to the
    # beams correctly, so checking a single extra modality (audio) is enough
    # to make sure things pass through properly.
    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
    model = "Qwen/Qwen2-Audio-7B-Instruct"
    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
    prompts = [
        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"  # noqa: E501
    ]

    with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSeq2SeqLM) as hf_model:
        audio_token_id = hf_model.config.audio_token_index
        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
        hf_outputs = hf_model.generate_beam_search(
            prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
            audios=audios,
        )

    with vllm_runner(model, dtype=dtype, **EXTRA_ENGINE_KWARGS) as vllm_model:
        vllm_outputs = vllm_model.generate_beam_search(
            prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
            audios=audios,
        )

    # Helper to drop audio placeholder tokens before comparing sequences.
    def seq_with_no_audio_toks(seq):
        return [tok for tok in seq if tok != audio_token_id]

    for i in range(len(prompts)):
        hf_output_ids, hf_output_texts = hf_outputs[i]
        vllm_output_ids, vllm_output_texts = vllm_outputs[i]

        for j, (hf_text, vllm_text) in enumerate(
            zip(hf_output_texts, vllm_output_texts)
        ):
            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
            print(hf_text)
            print(f">>>{j}-th vllm output:")
            print(vllm_text)
        assert len(hf_output_ids) == len(vllm_output_ids)

        for j in range(len(hf_output_ids)):
            # Compare everything except the audio tokens: the transformers
            # helper expands the audio token to match the number of audio
            # features, while the vLLM helper keeps the single audio token
            # from the input text.
            filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
            filtered_vllm_output_ids = seq_with_no_audio_toks(vllm_output_ids[j])

            # HF output IDs may include a trailing end-of-sequence token;
            # strip it before comparing.
            if len(filtered_hf_output_ids) == len(filtered_vllm_output_ids) + 1:
                assert filtered_hf_output_ids[-1] == eos_token_id
                filtered_hf_output_ids = filtered_hf_output_ids[:-1]

            assert filtered_hf_output_ids == filtered_vllm_output_ids