[Misc][Refactor] Introduce ExecuteModelRequest (#4540)

Cody Yu
2024-05-03 17:47:07 -07:00
committed by GitHub
parent 344bf7cd2d
commit bc8ad68455
23 changed files with 355 additions and 511 deletions
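
The tests below migrate from the test-local ExecuteModelData bundle (removed from .utils) to a single ExecuteModelRequest imported from vllm.sequence, so worker.execute_model() receives one request object instead of unpacked keyword arguments plus a separate num_lookahead_slots. As a rough sketch of what the diff assumes about the new type — the field set, types, and defaults here are inferred from how the tests construct and compare it, not taken from vllm.sequence itself:

# Rough sketch of the new request type; fields and defaults are inferred
# from how the tests construct and unpack it, not from vllm.sequence.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ExecuteModelRequest:
    # Scheduled sequence groups; previously the first ExecuteModelData field.
    seq_group_metadata_list: List[Any]
    # Cache-block moves that ExecuteModelData carried as separate fields.
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    # Lookahead slots for speculative decoding (the tests' `k`).
    num_lookahead_slots: int = 0


# Call-site change the tests exercise:
#   before: worker.execute_model(**execute_model_data.to_dict(),
#                                num_lookahead_slots=k)
#   after:  worker.execute_model(execute_model_req=ExecuteModelRequest(
#               seq_group_metadata_list=seq_group_metadata_list,
#               num_lookahead_slots=k))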


@@ -7,7 +7,7 @@ import torch
 
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)
-from .utils import (ExecuteModelData, create_batch, create_sampler_output_list,
-                    mock_worker)
+from .utils import create_batch, create_sampler_output_list, mock_worker
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=execute_model_req)
 
     call_args_list = draft_worker.get_spec_proposals.call_args_list
     assert len(call_args_list) == 1
 
     for args, _ in call_args_list:
-        (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
-         blocks_to_copy, actual_k) = args
-        actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
-                                                     blocks_to_swap_in,
-                                                     blocks_to_swap_out,
-                                                     blocks_to_copy)
-        assert actual_execute_model_data == execute_model_data
-        assert actual_k == k
+        actual_execute_model_data = args[0]
+        assert actual_execute_model_data == execute_model_req
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
+    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
         batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     target_worker.execute_model.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     seen_contexts = []
 
     call_args_list = target_worker.execute_model.call_args_list
     assert len(call_args_list) == 1
-    for args, kwargs in call_args_list:
-        target_execute_model_data = ExecuteModelData.from_dict(kwargs)
+    for _, kwargs in call_args_list:
+        seq_group_metadata_list = kwargs[
+            "execute_model_req"].seq_group_metadata_list
 
-        assert len(target_execute_model_data.seq_group_metadata_list) == (
-            k + 1) * batch_size
-
-        for seq_group_metadata in (
-                target_execute_model_data.seq_group_metadata_list):
+        assert len(seq_group_metadata_list) == (k + 1) * batch_size
+
+        for seq_group_metadata in seq_group_metadata_list:
             for seq_data in seq_group_metadata.seq_data.values():
                 seen_contexts.append(seq_data.get_token_ids())
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     rejection_sampler.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     assert len(rejection_sampler.call_args_list) == 1
     _, kwargs = rejection_sampler.call_args_list[0]
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
     rejection_sampler.return_value = rejection_sampler_output
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
 
     expected_output = create_sampler_output_list(
         token_ids=rejection_sampler_output.transpose(0, 1),
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
 
     seq_ids = [
         next(iter(seq_group_metadata.seq_data.keys()))
-        for seq_group_metadata in execute_model_data.seq_group_metadata_list
+        for seq_group_metadata in seq_group_metadata_list
     ]
     actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
     expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     metrics_collector.maybe_collect_rejsample_metrics.return_value = (
         mock_rejsample_metrics)
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
 
     assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
 
     call_args_list = (
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.parametrize('k', [0, 5])
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
                                                  prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.skip_global_cleanup