[Core] Asynchronous Output Processor (#7049)

Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
This commit is contained in:
Megha Agarwal
2024-08-26 20:53:20 -07:00
committed by GitHub
parent 015e6cc252
commit 2eedede875
21 changed files with 652 additions and 214 deletions

View File

@@ -140,6 +140,7 @@ class ModelConfig:
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -172,6 +173,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
@@ -326,6 +328,49 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
def verify_async_output_proc(self, parallel_config, speculative_config,
                             device_config) -> None:
    """Turn off async output processing when the configuration rules it out.

    Does nothing if ``self.use_async_output_proc`` is already falsy.
    Otherwise the checks below run in order, and the first one that matches
    sets ``self.use_async_output_proc`` to ``False`` and stops:

    1. ``parallel_config.pipeline_parallel_size > 1``
    2. ``device_config.device_type != "cuda"``
    3. ``envs.VLLM_USE_RAY_SPMD_WORKER`` is truthy
    4. ``self.enforce_eager`` is truthy
    5. ``self.embedding_mode`` is truthy (no warning is logged)
    6. ``speculative_config`` is truthy

    Every case except embedding mode logs a warning explaining why the
    feature was disabled. If no check matches, the flag is left unchanged.

    Args:
        parallel_config: Object exposing ``pipeline_parallel_size``.
        speculative_config: Only tested for truthiness; a falsy value
            (e.g. ``None``) means no speculative decoding is configured.
        device_config: Object exposing ``device_type``.
    """
    if not self.use_async_output_proc:
        # Already disabled; nothing to check.
        return

    if parallel_config.pipeline_parallel_size > 1:
        logger.warning("Async output processing can not be enabled "
                       "with pipeline parallel")
        self.use_async_output_proc = False
        return

    if device_config.device_type != "cuda":
        logger.warning(
            "Async output processing is only supported for CUDA."
            " Disabling it for other platforms.")
        self.use_async_output_proc = False
        return

    if envs.VLLM_USE_RAY_SPMD_WORKER:
        logger.warning(
            "Async output processing can not be enabled with ray spmd")
        self.use_async_output_proc = False
        return

    if self.enforce_eager:
        logger.warning(
            "To see benefits of async output processing, enable CUDA "
            "graph. Since, enforce-eager is enabled, async output "
            "processor cannot be used")
        # enforce_eager is truthy in this branch, so the flag is always
        # cleared; assign False directly instead of `not self.enforce_eager`.
        self.use_async_output_proc = False
        return

    if self.embedding_mode:
        # Async postprocessor is not necessary with embedding mode
        # since there is no token generation. Return like every other
        # disabling branch so that no unrelated warning is logged below.
        self.use_async_output_proc = False
        return

    if speculative_config:
        logger.warning("Async output processing is not supported with"
                       " speculative decoding currently.")
        self.use_async_output_proc = False
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
@@ -358,6 +403,11 @@ class ModelConfig:
"fallback to the eager mode.")
self.enforce_eager = True
if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""
@@ -1769,6 +1819,9 @@ class EngineConfig:
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)