[Misc][Refactor] Introduce ExecuteModelRequest (#4540)
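Summary for reviewers: the per-call kwargs that the spec-decode tests used to unpack from a test-local `ExecuteModelData` helper (the sequence group metadata, the block-swap dicts, and `num_lookahead_slots` / `proposal_len`) are folded into a single `ExecuteModelRequest` from `vllm.sequence`, and the test-only helper is deleted. A minimal sketch of the new calling convention, assuming `worker`, `seq_group_metadata_list`, and `k` come from the existing test utilities (setup elided):

```python
from vllm.sequence import ExecuteModelRequest

# Old style (removed in this diff): unpack a test-local ExecuteModelData into kwargs.
#   worker.execute_model(**execute_model_data.to_dict(), num_lookahead_slots=k)

# New style: one request object carries the batch plus the lookahead slots.
execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,  # e.g. from create_batch(...)
    num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)
```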
@@ -5,13 +5,12 @@ import pytest
 import torch
 
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
-                    create_execute_model_data,
                     create_seq_group_metadata_from_prompts, create_worker,
                     patch_execute_model_with_seeds, zero_kv_cache)
 
@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
 
     final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
 
-    multi_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    single_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    multi_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
 
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     actual_output, _ = multi_step_worker.sampler_output(
-        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=multi_step_seq_group),
+        sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
 
+    single_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
     zero_kv_cache(worker.cache_engine)
     set_random_seed(seed)
     expected_output = worker.execute_model(
-        **single_step_execute_model_data.to_dict(), )[0]
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=single_step_seq_group))[0]
 
     actual_token_ids = [
         output.samples[0].output_token for output in actual_output
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
     worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
 
     continuations = [[1] for _ in prompts]
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            continuations=continuations,
-            final_prompt_lens=final_prompt_lens), )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
 
     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     multi_step_output, _ = multi_step_worker.sampler_output(
-        **execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=num_steps)
 
     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
 
     for _ in multi_step_output:
 
-        execute_model_data = create_execute_model_data(
-            create_seq_group_metadata_from_prompts(
-                prompts,
-                num_gpu_blocks,
-                block_size,
-                continuations=continuations,
-                final_prompt_lens=final_prompt_lens))
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
 
         single_step_output.extend(
-            worker.execute_model(**execute_model_data.to_dict(), ))
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
 
         # Append output tokens to new sequence data.
         for i, seq_group_output in enumerate(single_step_output[-1]):
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
         max_proposal_len=prompt_len + k - 1,
     )
 
-    execute_model_data, _, _ = create_batch(batch_size,
-                                            k,
-                                            prompt_len=prompt_len)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prompt_len=prompt_len)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(
+    seq_group_metadata_list, _, _ = create_batch(
         batch_size,
         k,
         prompt_len=prompt_len,
         prev_output_token_len=prev_output_token_len,
     )
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
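Note on the proposer hunks above and the ngram tests below: `Top1Proposer.get_proposals` no longer takes a separate `proposal_len` kwarg; the lookahead length now rides on the request as `num_lookahead_slots`. A minimal sketch, assuming `proposer`, `seq_group_metadata_list`, and `k` come from the surrounding test setup:

```python
from vllm.sequence import ExecuteModelRequest

# The proposal length that used to be passed as proposal_len=k is now a
# field on the request object itself.
proposals = proposer.get_proposals(
    execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))
```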
@@ -1,10 +1,10 @@
 import torch
 
+from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
-from .utils import (create_execute_model_data,
-                    create_seq_group_metadata_from_prompts, create_worker)
+from .utils import create_seq_group_metadata_from_prompts, create_worker
 
 
 def test_ngram_algo_correctness_for_single_no_match():
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
 
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
 
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
 
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -7,7 +7,7 @@ import torch
 
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)
 
-from .utils import (ExecuteModelData, create_batch, create_sampler_output_list,
-                    mock_worker)
+from .utils import create_batch, create_sampler_output_list, mock_worker
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=execute_model_req)
 
     call_args_list = draft_worker.get_spec_proposals.call_args_list
     assert len(call_args_list) == 1
 
     for args, _ in call_args_list:
-        (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
-         blocks_to_copy, actual_k) = args
-        actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
-                                                     blocks_to_swap_in,
-                                                     blocks_to_swap_out,
-                                                     blocks_to_copy)
-        assert actual_execute_model_data == execute_model_data
-        assert actual_k == k
+        actual_execute_model_data = args[0]
+        assert actual_execute_model_data == execute_model_req
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
+    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
         batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     target_worker.execute_model.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     seen_contexts = []
 
     call_args_list = target_worker.execute_model.call_args_list
     assert len(call_args_list) == 1
-    for args, kwargs in call_args_list:
-        target_execute_model_data = ExecuteModelData.from_dict(kwargs)
+    for _, kwargs in call_args_list:
+        seq_group_metadata_list = kwargs[
+            "execute_model_req"].seq_group_metadata_list
 
-        assert len(target_execute_model_data.seq_group_metadata_list) == (
-            k + 1) * batch_size
-        for seq_group_metadata in (
-                target_execute_model_data.seq_group_metadata_list):
+        assert len(seq_group_metadata_list) == (k + 1) * batch_size
+        for seq_group_metadata in seq_group_metadata_list:
             for seq_data in seq_group_metadata.seq_data.values():
                 seen_contexts.append(seq_data.get_token_ids())
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     rejection_sampler.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     assert len(rejection_sampler.call_args_list) == 1
     _, kwargs = rejection_sampler.call_args_list[0]
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
 
     rejection_sampler.return_value = rejection_sampler_output
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
 
     expected_output = create_sampler_output_list(
         token_ids=rejection_sampler_output.transpose(0, 1),
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
 
     seq_ids = [
         next(iter(seq_group_metadata.seq_data.keys()))
-        for seq_group_metadata in execute_model_data.seq_group_metadata_list
+        for seq_group_metadata in seq_group_metadata_list
     ]
     actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
     expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     metrics_collector.maybe_collect_rejsample_metrics.return_value = (
         mock_rejsample_metrics)
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
     assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
 
     call_args_list = (
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.parametrize('k', [0, 5])
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.skip_global_cleanup
@@ -1,4 +1,3 @@
-from dataclasses import dataclass, fields
 from itertools import count
 from typing import Dict, Iterable, List, Optional, Union
 from unittest.mock import MagicMock
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker
 
 
-@dataclass
-class ExecuteModelData:
-    """Helper data structure which facilitates cleaner tests.
-    """
-    seq_group_metadata_list: List[SequenceGroupMetadata]
-    blocks_to_swap_in: Dict[int, int]
-    blocks_to_swap_out: Dict[int, int]
-    blocks_to_copy: Dict[int, List[int]]
-
-    def to_dict(self):
-        return dict(
-            (field.name, getattr(self, field.name)) for field in fields(self))
-
-    @classmethod
-    def from_dict(cls, d):
-        cleaned = dict((field.name, d[field.name]) for field in fields(cls))
-        return cls(**cleaned)
-
-
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size
 
 
-def create_execute_model_data(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    blocks_to_swap_in: Optional[Dict[int, int]] = None,
-    blocks_to_swap_out: Optional[Dict[int, int]] = None,
-    blocks_to_copy: Optional[Dict[int, int]] = None,
-) -> ExecuteModelData:
-    if blocks_to_swap_in is None:
-        blocks_to_swap_in = {}
-    if blocks_to_swap_out is None:
-        blocks_to_swap_out = {}
-    if blocks_to_copy is None:
-        blocks_to_copy = {}
-
-    return ExecuteModelData(
-        seq_group_metadata_list=seq_group_metadata_list,
-        blocks_to_swap_in=blocks_to_swap_in,
-        blocks_to_swap_out=blocks_to_swap_out,
-        blocks_to_copy=blocks_to_copy,
-    )
-
-
 def mock_worker(cls=None,
                 vocab_size: int = 30_000,
                 max_model_len: int = 2048,
@@ -258,8 +217,7 @@ def create_batch(batch_size,
         for prompt, prev_output_token in zip(prompts, prev_output_tokens)
     ]
 
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
-                                               block_size, final_prompt_lens,
-                                               prev_output_tokens, seq_ids), )
-    return execute_model_data, prompts, prev_output_tokens
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts, num_gpu_blocks, block_size, final_prompt_lens,
+        prev_output_tokens, seq_ids)
+    return seq_group_metadata_list, prompts, prev_output_tokens
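With `ExecuteModelData` and `create_execute_model_data` removed from the test utils, `create_batch` now returns the bare `seq_group_metadata_list` and each test wraps it in a request itself. A minimal usage sketch under that assumption (`batch_size` and `k` as in the tests above):

```python
from vllm.sequence import ExecuteModelRequest

# create_batch now yields the raw metadata list instead of an ExecuteModelData.
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
    batch_size, k)
execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
```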