[Spec Decode] Introduce DraftModelRunner (#5799)
This commit is contained in:
@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceOutput)
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
T = TypeVar("T", bound=Worker)
|
||||
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
|
||||
num_gpu_blocks: int,
|
||||
seed: int,
|
||||
is_driver_worker: bool = True,
|
||||
enforce_eager: bool = True) -> T:
|
||||
enforce_eager: bool = True,
|
||||
model_runner_cls: Optional[ModelRunner] = None) -> T:
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
seed=seed,
|
||||
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
model_runner_cls=model_runner_cls,
|
||||
)
|
||||
|
||||
worker.init_device()
|
||||
|
||||
Reference in New Issue
Block a user