[Spec Decode] Introduce DraftModelRunner (#5799)

This commit is contained in:
Cody Yu
2024-06-28 09:17:51 -07:00
committed by GitHub
parent b90d8cd832
commit b2c620230a
15 changed files with 257 additions and 36 deletions

View File

@@ -7,6 +7,7 @@ import torch
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
Worker,
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(