[Core] Increase default max_num_batched_tokens for multimodal models (#8028)

commit 98cef6a227
parent f97be32d1d
Author:    Cyrus Leung
Committed: GitHub
Date:      2024-08-30 23:20:34 +08:00
4 changed files with 33 additions and 12 deletions

vllm/config.py

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
+_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
 
 _PP_SUPPORTED_MODELS = [
     "AquilaModel",
@@ -571,6 +572,10 @@ class ModelConfig:
"""Extract the embedding model flag.""" """Extract the embedding model flag."""
return self.embedding_mode return self.embedding_mode
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache. """Configuration for the KV cache.
@@ -947,25 +952,36 @@ class SchedulerConfig:
                  num_lookahead_slots: int = 0,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
-                 embedding_mode: Optional[bool] = False,
+                 embedding_mode: bool = False,
+                 is_multimodal_model: bool = False,
                  preemption_mode: Optional[str] = None,
                  num_scheduler_steps: int = 1,
                  send_delta_data: bool = False) -> None:
-        if max_num_batched_tokens is not None:
-            self.max_num_batched_tokens = max_num_batched_tokens
-        else:
+        if max_num_batched_tokens is None:
             if enable_chunked_prefill:
                 # It is the values that have the best balance between ITL
                 # and TTFT on A100. Note it is not optimized for throughput.
-                self.max_num_batched_tokens = 512
-            elif embedding_mode:
-                # For embedding, choose specific value for higher throughput
-                self.max_num_batched_tokens = max(
-                    max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
+                max_num_batched_tokens = 512
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
-                self.max_num_batched_tokens = max(max_model_len, 2048)
+                max_num_batched_tokens = max(max_model_len, 2048)
+
+            if embedding_mode:
+                # For embedding, choose specific value for higher throughput
+                max_num_batched_tokens = max(
+                    max_num_batched_tokens,
+                    _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
+                )
+            if is_multimodal_model:
+                # The value needs to be at least the number of multimodal tokens
+                max_num_batched_tokens = max(
+                    max_num_batched_tokens,
+                    _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                )
+
+        self.max_num_batched_tokens = max_num_batched_tokens
+
         if enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",

vllm/engine/arg_utils.py

@@ -921,6 +921,7 @@ class EngineArgs:
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
+            is_multimodal_model=model_config.is_multimodal_model,
             preemption_mode=self.preemption_mode,
             num_scheduler_steps=self.num_scheduler_steps,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER

vllm/engine/llm_engine.py

@@ -2019,7 +2019,7 @@ class LLMEngine:
         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
 
-        if self.model_config.multimodal_config is not None:
+        if self.model_config.is_multimodal_model:
             max_prompt_len = self.model_config.max_model_len
 
             if len(prompt_ids) > max_prompt_len:
@@ -2030,3 +2030,7 @@ class LLMEngine:
"number of text tokens plus multimodal tokens. For image " "number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number " "inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.") "of images, and possibly their aspect ratios as well.")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
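
The validation these two hunks retouch can be sketched as a standalone helper.
Everything below is illustrative: validate_multimodal_prompt_len is a
hypothetical name, and the error message paraphrases the one spread across the
surrounding context lines:

    from typing import List, Optional

    def validate_multimodal_prompt_len(
        prompt_ids: Optional[List[int]],
        max_model_len: int,
        is_multimodal_model: bool,
    ) -> None:
        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")
        # For multimodal models, the prompt length already includes the
        # expanded multimodal placeholder tokens, so it is checked against
        # the full context length.
        if is_multimodal_model and len(prompt_ids) > max_model_len:
            raise ValueError(
                f"The prompt (total length {len(prompt_ids)}) is too long "
                f"to fit into the model (context length {max_model_len}). "
                "The length is the number of text tokens plus multimodal "
                "tokens.")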

vllm/worker/utils.py

@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
 
-    if enc_dec_mr.model_config.multimodal_config is not None:
+    if enc_dec_mr.model_config.is_multimodal_model:
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
 