[Spec Decode] Introduce DraftModelRunner (#5799)

This commit is contained in:
Cody Yu
2024-06-28 09:17:51 -07:00
committed by GitHub
parent b90d8cd832
commit b2c620230a
15 changed files with 257 additions and 36 deletions

View File

@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True) -> T:
enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
)
worker.init_device()