Re-enable the 80 char line width limit (#3305)
This commit is contained in:
@@ -4,12 +4,15 @@ import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, split_num_cache_blocks_evenly
|
||||
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
|
||||
split_num_cache_blocks_evenly)
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from .utils import mock_worker, create_batch, ExecuteModelData, create_sampler_output_list
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics, AsyncMetricsCollector
|
||||
from .utils import (mock_worker, create_batch, ExecuteModelData,
|
||||
create_sampler_output_list)
|
||||
from vllm.spec_decode.metrics import (SpecDecodeWorkerMetrics,
|
||||
AsyncMetricsCollector)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@@ -391,13 +394,15 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
|
||||
mock_rejsample_metrics = MagicMock(
|
||||
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
|
||||
metrics_collector.maybe_collect_rejsample_metrics.return_value = mock_rejsample_metrics
|
||||
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
|
||||
mock_rejsample_metrics)
|
||||
|
||||
output = worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_spec_tokens=k)
|
||||
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
|
||||
|
||||
call_args_list = metrics_collector.maybe_collect_rejsample_metrics.call_args_list
|
||||
call_args_list = (
|
||||
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
|
||||
assert len(call_args_list) == 1
|
||||
args, kwargs = call_args_list[0]
|
||||
assert args[0] == k or kwargs.get('k', -1) == k
|
||||
@@ -547,7 +552,8 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
|
||||
|
||||
target_worker.profile_num_available_blocks.return_value = (
|
||||
available_gpu_blocks, available_cpu_blocks)
|
||||
target_worker.get_cache_block_size_bytes.return_value = target_cache_block_size_bytes
|
||||
target_worker.get_cache_block_size_bytes.return_value = (
|
||||
target_cache_block_size_bytes)
|
||||
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
|
||||
Reference in New Issue
Block a user