[Speculative decoding] Adding configuration object for speculative decoding (#3706)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
This commit is contained in:
@@ -5,7 +5,8 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
import vllm
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics import StatLogger, Stats
|
||||
@@ -52,6 +53,11 @@ class LLMEngine:
|
||||
parallel_config: The configuration related to distributed execution.
|
||||
scheduler_config: The configuration related to the request scheduler.
|
||||
device_config: The configuration related to the device.
|
||||
lora_config (Optional): The configuration related to serving multi-LoRA.
|
||||
vision_language_config (Optional): The configuration related to vision
|
||||
language models.
|
||||
speculative_config (Optional): The configuration related to speculative
|
||||
decoding.
|
||||
executor_class: The model executor class for managing distributed
|
||||
execution.
|
||||
log_stats: Whether to log statistics.
|
||||
@@ -66,7 +72,8 @@ class LLMEngine:
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional["VisionLanguageConfig"],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
executor_class: Type[ExecutorBase],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
@@ -74,6 +81,7 @@ class LLMEngine:
|
||||
logger.info(
|
||||
f"Initializing an LLM engine (v{vllm.__version__}) with config: "
|
||||
f"model={model_config.model!r}, "
|
||||
f"speculative_config={speculative_config!r}, "
|
||||
f"tokenizer={model_config.tokenizer!r}, "
|
||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||
f"revision={model_config.revision}, "
|
||||
@@ -100,17 +108,23 @@ class LLMEngine:
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.speculative_config = speculative_config
|
||||
self.log_stats = log_stats
|
||||
self._verify_args()
|
||||
|
||||
self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
self.seq_counter = Counter()
|
||||
|
||||
self.model_executor = executor_class(model_config, cache_config,
|
||||
parallel_config, scheduler_config,
|
||||
device_config, lora_config,
|
||||
vision_language_config)
|
||||
self.model_executor = executor_class(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
speculative_config=speculative_config,
|
||||
)
|
||||
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
if is_usage_stats_enabled():
|
||||
@@ -171,30 +185,28 @@ class LLMEngine:
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
device_config = engine_configs[4]
|
||||
engine_config = engine_args.create_engine_config()
|
||||
|
||||
# Initialize the cluster and specify the executor class.
|
||||
if device_config.device_type == "neuron":
|
||||
if engine_config.device_config.device_type == "neuron":
|
||||
from vllm.executor.neuron_executor import NeuronExecutor
|
||||
executor_class = NeuronExecutor
|
||||
elif device_config.device_type == "cpu":
|
||||
elif engine_config.device_config.device_type == "cpu":
|
||||
from vllm.executor.cpu_executor import CPUExecutor
|
||||
executor_class = CPUExecutor
|
||||
elif parallel_config.worker_use_ray:
|
||||
initialize_ray_cluster(parallel_config)
|
||||
elif engine_config.parallel_config.worker_use_ray:
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
from vllm.executor.ray_gpu_executor import RayGPUExecutor
|
||||
executor_class = RayGPUExecutor
|
||||
else:
|
||||
assert parallel_config.world_size == 1, (
|
||||
assert engine_config.parallel_config.world_size == 1, (
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
executor_class = GPUExecutor
|
||||
|
||||
# Create the LLM engine.
|
||||
engine = cls(
|
||||
*engine_configs,
|
||||
**engine_config.to_dict(),
|
||||
executor_class=executor_class,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
|
||||
Reference in New Issue
Block a user