[Core] Increase default max_num_batched_tokens for multimodal models (#8028)
vllm/config.py
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
+_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
 
 _PP_SUPPORTED_MODELS = [
     "AquilaModel",
@@ -571,6 +572,10 @@ class ModelConfig:
         """Extract the embedding model flag."""
         return self.embedding_mode
 
+    @property
+    def is_multimodal_model(self) -> bool:
+        return self.multimodal_config is not None
+
 
 class CacheConfig:
     """Configuration for the KV cache.
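
For context, a minimal sketch of what the new property does in isolation (ConfigSketch is a stand-in for illustration, not vLLM's actual class): it is just a null check on multimodal_config, giving call sites one named predicate instead of a repeated attribute check; the llm_engine.py and worker/utils.py hunks below switch to it.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ConfigSketch:
    """Stand-in for vllm.config.ModelConfig (illustration only)."""
    multimodal_config: Optional[object] = None

    @property
    def is_multimodal_model(self) -> bool:
        # Same null check the diff adds to ModelConfig.
        return self.multimodal_config is not None

assert not ConfigSketch().is_multimodal_model
assert ConfigSketch(multimodal_config=object()).is_multimodal_model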
@@ -947,25 +952,36 @@ class SchedulerConfig:
                  num_lookahead_slots: int = 0,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
-                 embedding_mode: Optional[bool] = False,
+                 embedding_mode: bool = False,
+                 is_multimodal_model: bool = False,
                  preemption_mode: Optional[str] = None,
                  num_scheduler_steps: int = 1,
                  send_delta_data: bool = False) -> None:
-        if max_num_batched_tokens is not None:
-            self.max_num_batched_tokens = max_num_batched_tokens
-        else:
+        if max_num_batched_tokens is None:
             if enable_chunked_prefill:
                 # It is the values that have the best balance between ITL
                 # and TTFT on A100. Note it is not optimized for throughput.
-                self.max_num_batched_tokens = 512
-            elif embedding_mode:
-                # For embedding, choose specific value for higher throughput
-                self.max_num_batched_tokens = max(
-                    max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
+                max_num_batched_tokens = 512
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
-                self.max_num_batched_tokens = max(max_model_len, 2048)
+                max_num_batched_tokens = max(max_model_len, 2048)
+
+            if embedding_mode:
+                # For embedding, choose specific value for higher throughput
+                max_num_batched_tokens = max(
+                    max_num_batched_tokens,
+                    _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
+                )
+            if is_multimodal_model:
+                # The value needs to be at least the number of multimodal tokens
+                max_num_batched_tokens = max(
+                    max_num_batched_tokens,
+                    _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                )
+
+        self.max_num_batched_tokens = max_num_batched_tokens
 
         if enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
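
The default-resolution logic above can be read as the following standalone sketch (resolve_max_num_batched_tokens is an invented name for illustration; the constants are copied from the diff). Note that the embedding and multimodal floors apply only when the user did not set max_num_batched_tokens explicitly.

from typing import Optional

_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096

def resolve_max_num_batched_tokens(
        max_num_batched_tokens: Optional[int],
        max_model_len: int,
        enable_chunked_prefill: bool = False,
        embedding_mode: bool = False,
        is_multimodal_model: bool = False) -> int:
    if max_num_batched_tokens is None:
        # 512 balances ITL and TTFT under chunked prefill; otherwise fall
        # back to max(max_model_len, 2048) for throughput.
        max_num_batched_tokens = (512 if enable_chunked_prefill else
                                  max(max_model_len, 2048))
        if embedding_mode:
            max_num_batched_tokens = max(
                max_num_batched_tokens,
                _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
        if is_multimodal_model:
            # Must cover the text tokens plus the multimodal placeholder tokens.
            max_num_batched_tokens = max(
                max_num_batched_tokens,
                _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
    return max_num_batched_tokens

# A multimodal model with max_model_len=2048 previously defaulted to 2048
# batched tokens; with this change it defaults to 4096:
assert resolve_max_num_batched_tokens(None, 2048, is_multimodal_model=True) == 4096
# An explicit user value is still respected unchanged:
assert resolve_max_num_batched_tokens(1024, 2048, is_multimodal_model=True) == 1024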
vllm/engine/arg_utils.py
@@ -921,6 +921,7 @@ class EngineArgs:
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
+            is_multimodal_model=model_config.is_multimodal_model,
             preemption_mode=self.preemption_mode,
             num_scheduler_steps=self.num_scheduler_steps,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
vllm/engine/llm_engine.py
@@ -2019,7 +2019,7 @@ class LLMEngine:
         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
 
-        if self.model_config.multimodal_config is not None:
+        if self.model_config.is_multimodal_model:
             max_prompt_len = self.model_config.max_model_len
 
             if len(prompt_ids) > max_prompt_len:
@@ -2030,3 +2030,7 @@ class LLMEngine:
                     "number of text tokens plus multimodal tokens. For image "
                     "inputs, the number of image tokens depends on the number "
                     "of images, and possibly their aspect ratios as well.")
+
+            # TODO: Find out how many placeholder tokens are there so we can
+            # check that chunked prefill does not truncate them
+            # max_batch_len = self.scheduler_config.max_num_batched_tokens
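
To make the reasoning in the error message concrete (hypothetical numbers, not vLLM code): for a multimodal model the tokenized prompt already includes the expanded image placeholder tokens, so it is the combined length that must fit within max_model_len.

# Hypothetical example: 56 text tokens plus one image whose placeholder
# expands to 2048 tokens (the expansion depends on the model, the number
# of images, and possibly their aspect ratios).
text_tokens = 56
image_placeholder_tokens = 2048
max_model_len = 2048

prompt_len = text_tokens + image_placeholder_tokens
if prompt_len > max_model_len:
    # 2104 > 2048: rejected even though the text alone easily fits.
    print(f"Prompt of {prompt_len} tokens exceeds max_model_len={max_model_len}")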
vllm/worker/utils.py
@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
 
-    if enc_dec_mr.model_config.multimodal_config is not None:
+    if enc_dec_mr.model_config.is_multimodal_model:
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])