[core] remove beam search from the core (#9105)

Author: youkaichao
Date: 2024-10-06 22:47:04 -07:00
Committed by: GitHub
Parent: c8f26bb636
Commit: 18b296fdb2
25 changed files with 98 additions and 596 deletions
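For downstream users, the visible effect is that SamplingParams no longer accepts use_beam_search; beam search is driven from outside the core engine instead. The sketch below is a rough migration example, assuming a vLLM release that exposes the standalone LLM.beam_search helper and BeamSearchParams; the exact signature and prompt format have shifted across releases, so treat it as illustrative rather than as the API added by this specific commit.

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)

# Old style (removed by this commit): beam search configured per request, e.g.
#     SamplingParams(max_tokens=128, ignore_eos=True,
#                    use_beam_search=True, best_of=2)
# Such call sites now fail when SamplingParams is constructed.

# New style: a dedicated entry point outside the core scheduler/sampler.
# (BeamSearchParams fields and the prompt format are assumptions based on
# later releases; adjust to the version actually in use.)
params = BeamSearchParams(beam_width=2, max_tokens=128, ignore_eos=True)
outputs = llm.beam_search([{"prompt": "The capital of France is"}], params)

for output in outputs:
    # Attribute names of the beam search output are also assumptions.
    print(output.sequences[0].text)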

@@ -85,73 +85,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Use a large block size to trigger more copy-on-writes.
        "block_size": 32,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "use_v2_block_manager": True,
    "preemption_mode": "swap"
}, {
    "use_v2_block_manager": True,
    "preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
                                        test_llm_generator, batch_size):
    """Verify beam search equality with block manager v1 and v2.

    This requires copy-on-writes; if the v1 and v2 output is the same, then
    we have some confidence cow is working.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        use_beam_search=True,
        best_of=2,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{

@@ -13,7 +13,6 @@ def create_dummy_prompt(
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
    prompt_tokens: Optional[List[int]] = None,
    min_tokens: int = 0,
@@ -37,7 +36,6 @@ def create_dummy_prompt(
                              seqs=[prompt],
                              arrival_time=time.time(),
                              sampling_params=SamplingParams(
                                  use_beam_search=use_beam_search,
                                  best_of=best_of,
                                  max_tokens=max_tokens,
                                  min_tokens=min_tokens),
@@ -52,7 +50,6 @@ def create_dummy_prompt_encoder_decoder(
    encoder_prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
    if not block_size:
@@ -85,9 +82,7 @@ def create_dummy_prompt_encoder_decoder(
                              from_decoder_prompt=False)
    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              sampling_params=SamplingParams(
                                  use_beam_search=use_beam_search,
                                  best_of=best_of),
                              sampling_params=SamplingParams(best_of=best_of),
                              arrival_time=time.time(),
                              lora_request=lora_request,
                              encoder_seq=encoder_prompt)
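The test helpers change accordingly: create_dummy_prompt and create_dummy_prompt_encoder_decoder lose their use_beam_search parameter, so callers simply stop passing it. A hypothetical caller-side update, for illustration only (the call site and import path below are assumptions, not code from this commit):

# Import path assumed; in-tree tests import the helper from tests/core/utils.py.
from tests.core.utils import create_dummy_prompt

# Before: a dummy request could ask for beam search.
#     seq, seq_group = create_dummy_prompt("1", prompt_length=16,
#                                          use_beam_search=True, best_of=2)

# After: only best_of remains; passing use_beam_search now raises a TypeError.
seq, seq_group = create_dummy_prompt("1", prompt_length=16, best_of=2)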