[Spec Decode] Introduce DraftModelRunner (#5799)

This commit is contained in:
Cody Yu
2024-06-28 09:17:51 -07:00
committed by GitHub
parent b90d8cd832
commit b2c620230a
15 changed files with 257 additions and 36 deletions

View File

@@ -7,6 +7,7 @@ import torch
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
Worker,
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(

View File

@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True) -> T:
enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
)
worker.init_device()