[Speculative decoding] Adding configuration object for speculative decoding (#3706)

Co-authored-by: Lily Liu <lilyliupku@gmail.com>
This commit is contained in:
Cade Daniel
2024-04-02 17:40:57 -07:00
committed by GitHub
parent a3c226e7eb
commit 5757d90e26
12 changed files with 394 additions and 61 deletions

View File

@@ -6,7 +6,8 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
@@ -41,6 +42,7 @@ class RayGPUExecutor(ExecutorBase):
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
@@ -49,6 +51,8 @@ class RayGPUExecutor(ExecutorBase):
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
assert (not speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group