[1/N] pass the complete config from engine to executor (#9933)

Signed-off-by: youkaichao <youkaichao@gmail.com>

Author:    youkaichao
Date:      2024-11-01 13:51:57 -07:00
Committer: GitHub
Parent:    598b6d7b07
Commit:    18bd7587b7

6 changed files with 64 additions and 136 deletions


@@ -13,11 +13,8 @@ import torch
 from typing_extensions import TypeIs, TypeVar

 import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
+                         ObservabilityConfig, ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                  SchedulerOutputs)
 from vllm.engine.arg_utils import EngineArgs
@@ -222,17 +219,7 @@ class LLMEngine:
     def __init__(
         self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        decoding_config: Optional[DecodingConfig],
-        observability_config: Optional[ObservabilityConfig],
-        prompt_adapter_config: Optional[PromptAdapterConfig],
+        vllm_config: EngineConfig,
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -240,6 +227,22 @@ class LLMEngine:
         input_registry: InputRegistry = INPUT_REGISTRY,
         use_cached_outputs: bool = False,
     ) -> None:
+        # TODO: remove the local variables and use self.* throughout the class.
+        model_config = self.model_config = vllm_config.model_config
+        cache_config = self.cache_config = vllm_config.cache_config
+        lora_config = self.lora_config = vllm_config.lora_config
+        parallel_config = self.parallel_config = vllm_config.parallel_config
+        scheduler_config = self.scheduler_config = vllm_config.scheduler_config
+        device_config = self.device_config = vllm_config.device_config
+        speculative_config = self.speculative_config = vllm_config.speculative_config  # noqa
+        load_config = self.load_config = vllm_config.load_config
+        decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
+        )
+        prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
+        observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
+        )
+
         logger.info(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
@@ -340,18 +343,7 @@ class LLMEngine:
         self.input_processor = input_registry.create_input_processor(
             model_config)

-        self.model_executor = executor_class(
-            model_config=model_config,
-            cache_config=cache_config,
-            parallel_config=parallel_config,
-            scheduler_config=scheduler_config,
-            device_config=device_config,
-            lora_config=lora_config,
-            speculative_config=speculative_config,
-            load_config=load_config,
-            prompt_adapter_config=prompt_adapter_config,
-            observability_config=self.observability_config,
-        )
+        self.model_executor = executor_class(vllm_config=vllm_config)

         if self.model_config.task != "embedding":
             self._initialize_kv_caches()
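The executor side accepts the same aggregate whole. Continuing the sketch above (same hypothetical stand-ins), this is roughly what `executor_class(vllm_config=vllm_config)` buys: the executor's signature stays fixed while the set of sub-configs evolves.

class ExecutorBase:  # stand-in for vllm.executor.executor_base.ExecutorBase
    def __init__(self, vllm_config: EngineConfig) -> None:
        # Pull out only the pieces this executor actually uses; callers no
        # longer need to know which ten configs an executor wants.
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config


# The engine constructs it with a single keyword argument, as in the diff:
model_executor = ExecutorBase(vllm_config=engine.vllm_config)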
@@ -582,7 +574,7 @@ class LLMEngine:
         executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(
-            **engine_config.to_dict(),
+            vllm_config=engine_config,
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
             usage_context=usage_context,
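The `from_engine_args` hunk suggests why `**engine_config.to_dict()` had to go: exploding the aggregate into keyword arguments breaks as soon as the config gains a field that the `__init__` signature does not list. A sketch of that failure mode, using `dataclasses.asdict` as a hypothetical stand-in for `to_dict()`:

from dataclasses import dataclass, asdict


@dataclass
class EngineConfig:  # hypothetical stand-in, as above
    model: str = "opt-125m"
    new_field: int = 0  # added later, e.g. by another contributor


class LLMEngine:
    def __init__(self, model: str) -> None:  # signature not updated
        self.model = model


# Exploding the config raises at every call site once new_field exists:
try:
    LLMEngine(**asdict(EngineConfig()))
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'new_field'

# Forwarding the whole object keeps the signature stable instead:
# LLMEngine(vllm_config=EngineConfig())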