Improve configs - ParallelConfig (#16332)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-04-10 18:34:37 +01:00
committed by GitHub
parent c1b57855ec
commit 0c54fc7273
2 changed files with 182 additions and 85 deletions

View File

@@ -4,13 +4,16 @@ import ast
import copy
import enum
import hashlib
import inspect
import json
import sys
import textwrap
import warnings
from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
@@ -104,6 +107,77 @@ class ModelImpl(str, enum.Enum):
TRANSFORMERS = "transformers"
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
"""
Get any docstrings placed after attribute assignments in a class body.
https://davidism.com/mit-license/
"""
def pairwise(iterable):
"""
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
Can be removed when Python 3.9 support is dropped.
"""
iterator = iter(iterable)
a = next(iterator, None)
for b in iterator:
yield a, b
a = b
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
if not isinstance(cls_node, ast.ClassDef):
raise TypeError("Given object was not a class.")
out = {}
# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)):
continue
doc = inspect.cleandoc(b.value.value)
# An assignment can have multiple targets (a = b = v), but an
# annotated assignment only has one target.
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
for target in targets:
# Must be assigning to a plain name.
if not isinstance(target, ast.Name):
continue
out[target.id] = doc
return out
def config(cls: type[Any]) -> type[Any]:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
"""
if not is_dataclass(cls):
raise TypeError("The decorated class must be a dataclass.")
attr_docs = get_attr_docs(cls)
for f in fields(cls):
if f.init and f.default is MISSING and f.default_factory is MISSING:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a default value."
)
if f.name not in attr_docs:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a docstring.")
return cls
class ModelConfig:
"""Configuration for the model.
@@ -1432,61 +1506,77 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"]
@config
@dataclass
class ParallelConfig:
"""Configuration for the distributed execution."""
pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
# Local rank of the data parallel group, defaults to global rank.
pipeline_parallel_size: int = 1
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_rank_local: Optional[int] = None
# IP of the data parallel master.
"""Local rank of the data parallel group, defaults to global rank."""
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
enable_expert_parallel: bool = False # Use EP instead of TP for MoE layers.
"""IP of the data parallel master."""
data_parallel_master_port: int = 29500
"""Port of the data parallel master."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
# parallel and large models.
max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallal loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
parallel and large models."""
# Disable the custom all-reduce kernel and fall back to NCCL.
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
# Config for the tokenizer pool. If None, will use synchronous tokenization.
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
"""Config for the tokenizer pool. If None, will use synchronous
tokenization."""
# Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
# ray distributed model workers placement group.
placement_group: Optional["PlacementGroup"] = None
"""ray distributed model workers placement group."""
# Backend to use for distributed model
# workers, either "ray" or "mp" (multiprocessing). If the product
# of pipeline_parallel_size and tensor_parallel_size is less than
# or equal to the number of GPUs available, "mp" will be used to
# keep processing on a single host. Otherwise, this will default
# to "ray" if Ray is installed and fail otherwise. Note that tpu
# and hpu only support Ray for distributed inference.
distributed_executor_backend: Optional[Union[str,
type["ExecutorBase"]]] = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
of pipeline_parallel_size and tensor_parallel_size is less than
or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, this will default
to "ray" if Ray is installed and fail otherwise. Note that tpu
and hpu only support Ray for distributed inference."""
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class
will be determined based on the platform."""
sd_worker_cls: str = "auto"
"""The full name of the worker class to use for speculative decofing.
If "auto", the worker class will be determined based on the platform."""
worker_extension_cls: str = ""
"""The full name of the worker extension class to use. The worker extension
class is dynamically inherited by the worker class. This is used to inject
new attributes and methods to the worker class for use in collective_rpc
calls."""
# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
# world_size_across_dp is TPxPPxDP, it is the size of the world
# including data parallelism.
"""world_size is TPxPP, it affects the number of workers we create."""
world_size_across_dp: int = field(init=False)
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
rank: int = 0
"""Global rank in distributed setup."""
def get_next_dp_init_port(self) -> int:
"""