[Spec Decode] Introduce DraftModelRunner (#5799)
This commit is contained in:
@@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker import Worker
|
||||
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
worker = create_worker(
|
||||
Worker,
|
||||
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
worker = create_worker(
|
||||
|
||||
Reference in New Issue
Block a user