[Speculative decoding] Adding configuration object for speculative decoding (#3706)

Co-authored-by: Lily Liu <lilyliupku@gmail.com>
This commit is contained in:
Cade Daniel
2024-04-02 17:40:57 -07:00
committed by GitHub
parent a3c226e7eb
commit 5757d90e26
12 changed files with 394 additions and 61 deletions

View File

@@ -5,7 +5,8 @@ from transformers import PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@@ -52,6 +53,11 @@ class LLMEngine:
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
vision_language_config (Optional): The configuration related to vision
language models.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
@@ -66,7 +72,8 @@ class LLMEngine:
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional["VisionLanguageConfig"],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -74,6 +81,7 @@ class LLMEngine:
logger.info(
f"Initializing an LLM engine (v{vllm.__version__}) with config: "
f"model={model_config.model!r}, "
f"speculative_config={speculative_config!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
@@ -100,17 +108,23 @@ class LLMEngine:
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.log_stats = log_stats
self._verify_args()
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
self.seq_counter = Counter()
self.model_executor = executor_class(model_config, cache_config,
parallel_config, scheduler_config,
device_config, lora_config,
vision_language_config)
self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
)
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
@@ -171,30 +185,28 @@ class LLMEngine:
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
device_config = engine_configs[4]
engine_config = engine_args.create_engine_config()
# Initialize the cluster and specify the executor class.
if device_config.device_type == "neuron":
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif device_config.device_type == "cpu":
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif parallel_config.worker_use_ray:
initialize_ray_cluster(parallel_config)
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
else:
assert parallel_config.world_size == 1, (
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
# Create the LLM engine.
engine = cls(
*engine_configs,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,