[platforms] absorb worker cls difference into platforms folder (#10555)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
vllm/config.py | 236 lines changed
@@ -926,56 +926,56 @@ class LoadConfig:
                 f"{rocm_supported_load_format}")


+@dataclass
 class ParallelConfig:
-    """Configuration for the distributed execution.
-
-    Args:
-        pipeline_parallel_size: Number of pipeline parallel groups.
-        tensor_parallel_size: Number of tensor parallel groups.
-        worker_use_ray: Deprecated, use distributed_executor_backend instead.
-        max_parallel_loading_workers: Maximum number of multiple batches
-            when load model sequentially. To avoid RAM OOM when using tensor
-            parallel and large models.
-        disable_custom_all_reduce: Disable the custom all-reduce kernel and
-            fall back to NCCL.
-        tokenizer_pool_config: Config for the tokenizer pool.
-            If None, will use synchronous tokenization.
-        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
-            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
-        placement_group: ray distributed model workers placement group.
-        distributed_executor_backend: Backend to use for distributed model
-            workers, either "ray" or "mp" (multiprocessing). If the product
-            of pipeline_parallel_size and tensor_parallel_size is less than
-            or equal to the number of GPUs available, "mp" will be used to
-            keep processing on a single host. Otherwise, this will default
-            to "ray" if Ray is installed and fail otherwise. Note that tpu
-            and hpu only support Ray for distributed inference.
-    """
-
-    def __init__(
-        self,
-        pipeline_parallel_size: int,
-        tensor_parallel_size: int,
-        worker_use_ray: Optional[bool] = None,
-        max_parallel_loading_workers: Optional[int] = None,
-        disable_custom_all_reduce: bool = False,
-        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
-        ray_workers_use_nsight: bool = False,
-        placement_group: Optional["PlacementGroup"] = None,
-        distributed_executor_backend: Optional[Union[
-            str, Type["ExecutorBase"]]] = None,
-    ) -> None:
-        self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
-        self.distributed_executor_backend = distributed_executor_backend
-        self.max_parallel_loading_workers = max_parallel_loading_workers
-        self.disable_custom_all_reduce = disable_custom_all_reduce
-        self.tokenizer_pool_config = tokenizer_pool_config
-        self.ray_workers_use_nsight = ray_workers_use_nsight
-        self.placement_group = placement_group
-        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
-
-        if worker_use_ray:
+    """Configuration for the distributed execution."""
+
+    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
+    tensor_parallel_size: int = 1  # Number of tensor parallel groups.
+
+    # Deprecated, use distributed_executor_backend instead.
+    worker_use_ray: Optional[bool] = None
+
+    # Maximum number of multiple batches
+    # when load model sequentially. To avoid RAM OOM when using tensor
+    # parallel and large models.
+    max_parallel_loading_workers: Optional[int] = None
+
+    # Disable the custom all-reduce kernel and fall back to NCCL.
+    disable_custom_all_reduce: bool = False
+
+    # Config for the tokenizer pool. If None, will use synchronous tokenization.
+    tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
+
+    # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
+    ray_workers_use_nsight: bool = False
+
+    # ray distributed model workers placement group.
+    placement_group: Optional["PlacementGroup"] = None
+
+    # Backend to use for distributed model
+    # workers, either "ray" or "mp" (multiprocessing). If the product
+    # of pipeline_parallel_size and tensor_parallel_size is less than
+    # or equal to the number of GPUs available, "mp" will be used to
+    # keep processing on a single host. Otherwise, this will default
+    # to "ray" if Ray is installed and fail otherwise. Note that tpu
+    # and hpu only support Ray for distributed inference.
+    distributed_executor_backend: Optional[Union[str,
+                                                 Type["ExecutorBase"]]] = None
+
+    # the full name of the worker class to use. If "auto", the worker class
+    # will be determined based on the platform.
+    worker_cls: str = "auto"
+
+    world_size: int = field(init=False)
+
+    rank: int = 0
+
+    def __post_init__(self) -> None:
+        self.world_size = self.pipeline_parallel_size * \
+            self.tensor_parallel_size
+
+        if self.worker_use_ray:
             if self.distributed_executor_backend is None:
                 self.distributed_executor_backend = "ray"
             elif not self.use_ray:
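The new `worker_cls: str = "auto"` field is what lets the per-platform code in the platforms folder pick the worker implementation, instead of each executor hard-coding one. A minimal self-contained sketch of that pattern (the helper and the device mapping below are hypothetical, not vLLM's actual resolution code):

    from dataclasses import dataclass

    @dataclass
    class ParallelConfigSketch:
        # "auto" means: let the platform substitute the real worker class.
        worker_cls: str = "auto"

    def resolve_worker_cls(config: ParallelConfigSketch, device: str) -> None:
        """Stand-in for a platform hook that fills in the worker class."""
        if config.worker_cls == "auto":
            config.worker_cls = {
                "cuda": "vllm.worker.worker.Worker",
                "cpu": "vllm.worker.cpu_worker.CPUWorker",
            }.get(device, "vllm.worker.worker.Worker")

    cfg = ParallelConfigSketch()
    resolve_worker_cls(cfg, "cpu")
    print(cfg.worker_cls)  # -> vllm.worker.cpu_worker.CPUWorker

Because the field is a plain string naming an importable class, each platform can point it anywhere, and user code can still override it explicitly before initialization.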
@@ -1026,7 +1026,6 @@ class ParallelConfig:
                                  backend)

         self._verify_args()
-        self.rank: int = 0

     @property
     def use_ray(self) -> bool:
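This hunk drops the `self.rank: int = 0` assignment because `rank` is now a declared dataclass field (see the hunk above), while `world_size` uses `field(init=False)` so it stays derived. A standard-library demo of that idiom (illustrative class, not vLLM code):

    from dataclasses import dataclass, field

    @dataclass
    class Demo:
        pipeline_parallel_size: int = 1
        tensor_parallel_size: int = 1
        rank: int = 0                        # plain field; replaces the old self.rank
        world_size: int = field(init=False)  # derived, excluded from __init__

        def __post_init__(self) -> None:
            self.world_size = (self.pipeline_parallel_size *
                               self.tensor_parallel_size)

    assert Demo(tensor_parallel_size=4).world_size == 4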
@@ -1059,100 +1058,97 @@ class ParallelConfig:
                 "run with Ray.")


+@dataclass
 class SchedulerConfig:
-    """Scheduler configuration.
-
-    Args:
-        task: The task to use the model for.
-        max_num_batched_tokens: Maximum number of tokens to be processed in
-            a single iteration.
-        max_num_seqs: Maximum number of sequences to be processed in a single
-            iteration.
-        max_model_len: Maximum length of a sequence (including prompt
-            and generated text).
-        num_lookahead_slots: The number of slots to allocate per sequence per
-            step, beyond the known token ids. This is used in speculative
-            decoding to store KV activations of tokens which may or may not be
-            accepted.
-        delay_factor: Apply a delay (of delay factor multiplied by previous
-            prompt latency) before scheduling next prompt.
-        enable_chunked_prefill: If True, prefill requests can be chunked based
-            on the remaining max_num_batched_tokens.
-        preemption_mode: Whether to perform preemption by swapping or
-            recomputation. If not specified, we determine the mode as follows:
-            We use recomputation by default since it incurs lower overhead than
-            swapping. However, when the sequence group has multiple sequences
-            (e.g., beam search), recomputation is not currently supported. In
-            such a case, we use swapping instead.
-        send_delta_data: Private API. If used, scheduler sends delta data to
-            workers instead of an entire data. It should be enabled only
-            when SPMD worker architecture is enabled. I.e.,
-            VLLM_USE_RAY_SPMD_WORKER=1
-        policy: The scheduling policy to use. "fcfs" (default) or "priority".
-    """
-
-    def __init__(self,
-                 task: _Task,
-                 max_num_batched_tokens: Optional[int],
-                 max_num_seqs: int,
-                 max_model_len: int,
-                 num_lookahead_slots: int = 0,
-                 delay_factor: float = 0.0,
-                 enable_chunked_prefill: bool = False,
-                 is_multimodal_model: bool = False,
-                 preemption_mode: Optional[str] = None,
-                 num_scheduler_steps: int = 1,
-                 multi_step_stream_outputs: bool = False,
-                 send_delta_data: bool = False,
-                 policy: str = "fcfs") -> None:
-        if max_num_batched_tokens is None:
-            if enable_chunked_prefill:
-                if num_scheduler_steps > 1:
+    """Scheduler configuration."""
+
+    task: str = "generate"  # The task to use the model for.
+
+    # Maximum number of tokens to be processed in a single iteration.
+    max_num_batched_tokens: int = field(default=None)  # type: ignore
+
+    # Maximum number of sequences to be processed in a single iteration.
+    max_num_seqs: int = 128
+
+    # Maximum length of a sequence (including prompt and generated text).
+    max_model_len: int = 8192
+
+    # The number of slots to allocate per sequence per
+    # step, beyond the known token ids. This is used in speculative
+    # decoding to store KV activations of tokens which may or may not be
+    # accepted.
+    num_lookahead_slots: int = 0
+
+    # Apply a delay (of delay factor multiplied by previous
+    # prompt latency) before scheduling next prompt.
+    delay_factor: float = 0.0
+
+    # If True, prefill requests can be chunked based
+    # on the remaining max_num_batched_tokens.
+    enable_chunked_prefill: bool = False
+
+    is_multimodal_model: bool = False
+
+    # Whether to perform preemption by swapping or
+    # recomputation. If not specified, we determine the mode as follows:
+    # We use recomputation by default since it incurs lower overhead than
+    # swapping. However, when the sequence group has multiple sequences
+    # (e.g., beam search), recomputation is not currently supported. In
+    # such a case, we use swapping instead.
+    preemption_mode: Optional[str] = None
+
+    num_scheduler_steps: int = 1
+
+    multi_step_stream_outputs: bool = False
+
+    # Private API. If used, scheduler sends delta data to
+    # workers instead of an entire data. It should be enabled only
+    # when SPMD worker architecture is enabled. I.e.,
+    # VLLM_USE_RAY_SPMD_WORKER=1
+    send_delta_data: bool = False
+
+    # The scheduling policy to use. "fcfs" (default) or "priority".
+    policy: str = "fcfs"
+
+    chunked_prefill_enabled: bool = field(init=False)
+
+    def __post_init__(self) -> None:
+        if self.max_num_batched_tokens is None:
+            if self.enable_chunked_prefill:
+                if self.num_scheduler_steps > 1:
                     # Multi-step Chunked-Prefill doesn't allow prompt-chunking
                     # for now. Have max_num_batched_tokens set to max_model_len
                     # so we don't reject sequences on account of a short
                     # max_num_batched_tokens.
-                    max_num_batched_tokens = max(max_model_len, 2048)
+                    self.max_num_batched_tokens = max(self.max_model_len, 2048)
                 else:
                     # It is the values that have the best balance between ITL
                     # and TTFT on A100. Note it is not optimized for throughput.
-                    max_num_batched_tokens = 512
+                    self.max_num_batched_tokens = 512
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
-                max_num_batched_tokens = max(max_model_len, 2048)
+                self.max_num_batched_tokens = max(self.max_model_len, 2048)

-        if task == "embedding":
+        if self.task == "embedding":
             # For embedding, choose specific value for higher throughput
-            max_num_batched_tokens = max(
-                max_num_batched_tokens,
+            self.max_num_batched_tokens = max(
+                self.max_num_batched_tokens,
                 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
             )
-        if is_multimodal_model:
+        if self.is_multimodal_model:
             # The value needs to be at least the number of multimodal tokens
-            max_num_batched_tokens = max(
-                max_num_batched_tokens,
+            self.max_num_batched_tokens = max(
+                self.max_num_batched_tokens,
                 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
             )

-        self.max_num_batched_tokens = max_num_batched_tokens
-
-        if enable_chunked_prefill:
+        if self.enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                 self.max_num_batched_tokens)

-        self.task: Final = task
-        self.max_num_seqs = max_num_seqs
-        self.max_model_len = max_model_len
-        self.num_lookahead_slots = num_lookahead_slots
-        self.delay_factor = delay_factor
-        self.chunked_prefill_enabled = enable_chunked_prefill
-        self.preemption_mode = preemption_mode
-        self.num_scheduler_steps = num_scheduler_steps
-        self.multi_step_stream_outputs = multi_step_stream_outputs
-        self.send_delta_data = send_delta_data
-        self.policy = policy
+        self.chunked_prefill_enabled = self.enable_chunked_prefill
         self._verify_args()

     def _verify_args(self) -> None:
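The `__post_init__` above preserves the old defaulting rules for `max_num_batched_tokens`. Restated as a small pure function for clarity (the function itself is not in vLLM; the branches and constants are copied from the diff):

    def default_max_num_batched_tokens(max_model_len: int,
                                       enable_chunked_prefill: bool,
                                       num_scheduler_steps: int) -> int:
        if enable_chunked_prefill:
            if num_scheduler_steps > 1:
                # Multi-step chunked-prefill cannot chunk prompts yet, so use
                # at least max_model_len to avoid rejecting long prompts.
                return max(max_model_len, 2048)
            # Measured best ITL/TTFT balance on A100; not throughput-optimal.
            return 512
        # Without chunked prefill, use at least 2048 for throughput.
        return max(max_model_len, 2048)

    assert default_max_num_batched_tokens(8192, False, 1) == 8192
    assert default_max_num_batched_tokens(8192, True, 1) == 512
    assert default_max_num_batched_tokens(1024, True, 4) == 2048

The embedding and multimodal floors (`_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS`, `_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS`) are then applied on top of this base value.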
@@ -2293,10 +2289,10 @@ class VllmConfig:

     model_config: ModelConfig = field(default=None, init=True)  # type: ignore
     cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
-    parallel_config: ParallelConfig = field(default=None,
-                                            init=True)  # type: ignore
-    scheduler_config: SchedulerConfig = field(default=None,
-                                              init=True)  # type: ignore
+    parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
+                                            init=True)
+    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
+                                              init=True)
     device_config: DeviceConfig = field(default=None,
                                         init=True)  # type: ignore
     load_config: LoadConfig = field(default=None, init=True)  # type: ignore
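Switching `parallel_config` and `scheduler_config` from `field(default=None)` to `field(default_factory=...)` means a bare `VllmConfig()` now gets real sub-configs instead of `None`, and each instance gets its own copy. A standard-library sketch of the difference (hypothetical classes, not vLLM code):

    from dataclasses import dataclass, field

    @dataclass
    class Inner:
        values: list = field(default_factory=list)

    @dataclass
    class Outer:
        # default_factory builds a fresh Inner per Outer; a shared
        # `default=Inner()` instance would alias state across objects.
        inner: Inner = field(default_factory=Inner)

    a, b = Outer(), Outer()
    a.inner.values.append(1)
    assert b.inner.values == []  # independent instances

This is also why the `# type: ignore` comments could be dropped on those two fields: the default value now actually matches the annotated type.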