Improve configs - TokenizerPoolConfig + DeviceConfig (#16603)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-04-17 12:19:42 +01:00
committed by GitHub
parent 99ed526101
commit d27ea94034
3 changed files with 136 additions and 81 deletions


@@ -182,6 +182,23 @@ def config(cls: type[Config]) -> type[Config]:
     return cls
+def get_field(cls: type[Config], name: str) -> Field:
+    """Get the default factory field of a dataclass by name. Used for getting
+    default factory fields in `EngineArgs`."""
+    if not is_dataclass(cls):
+        raise TypeError("The given class is not a dataclass.")
+    cls_fields = {f.name: f for f in fields(cls)}
+    if name not in cls_fields:
+        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
+    named_field: Field = cls_fields.get(name)
+    if (default_factory := named_field.default_factory) is not MISSING:
+        return field(default_factory=default_factory)
+    if (default := named_field.default) is not MISSING:
+        return field(default=default)
+    raise ValueError(
+        f"{cls.__name__}.{name} must have a default value or default factory.")
 class ModelConfig:
     """Configuration for the model.
@@ -1364,20 +1381,26 @@ class CacheConfig:
logger.warning("Possibly too large swap space. %s", msg)
PoolType = Literal["ray"]
@config
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
"""Configuration for the tokenizer pool."""
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type.
"""
pool_size: int
pool_type: Union[str, type["BaseTokenizerGroup"]]
extra_config: dict
pool_size: int = 0
"""Number of tokenizer workers in the pool to use for asynchronous
tokenization. If 0, will use synchronous tokenization."""
pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
"""Type of tokenizer pool to use for asynchronous tokenization. Ignored if
tokenizer_pool_size is 0."""
extra_config: dict = field(default_factory=dict)
"""Additional config for the pool. The way the config will be used depends
on the pool type. This should be a JSON string that will be parsed into a
dictionary. Ignored if tokenizer_pool_size is 0."""
def compute_hash(self) -> str:
"""
@@ -1408,7 +1431,7 @@ class TokenizerPoolConfig:
     @classmethod
     def create_config(
         cls, tokenizer_pool_size: int,
-        tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
+        tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1483,7 +1506,7 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: Optional[Union[str, dict]] = None
+    model_loader_extra_config: dict = field(default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format. This should be a JSON string that
     will be parsed into a dictionary."""
@@ -1514,10 +1537,6 @@ class LoadConfig:
         return hash_str
     def __post_init__(self):
-        model_loader_extra_config = self.model_loader_extra_config or {}
-        if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(
-                model_loader_extra_config)
         if isinstance(self.load_format, str):
             load_format = self.load_format.lower()
             self.load_format = LoadFormat(load_format)
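Because dataclass fields cannot use a mutable literal as a default, `model_loader_extra_config` now defaults to an empty dict via `field(default_factory=dict)`, and `__post_init__` no longer parses JSON strings. A minimal sketch, assuming `LoadConfig` is importable from `vllm.config` and can be built with defaults; the JSON payload is hypothetical:

import json

from vllm.config import LoadConfig

cfg = LoadConfig()
assert cfg.model_loader_extra_config == {}  # no more None/str special-casing

# A caller holding a JSON string (e.g. from a CLI flag) would now parse it
# before constructing the config, since __post_init__ no longer does so.
raw = '{"some_loader_option": true}'  # hypothetical key, for illustration only
cfg = LoadConfig(model_loader_extra_config=json.loads(raw))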
@@ -2029,9 +2048,19 @@ class SchedulerConfig:
         return self.num_scheduler_steps > 1
+Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
+@config
 @dataclass
 class DeviceConfig:
-    device: Optional[torch.device]
-    device_type: str
+    """Configuration for the device to use for vLLM execution."""
+    device: Union[Device, torch.device] = "auto"
+    """Device type for vLLM execution."""
+    device_type: str = field(init=False)
+    """Device type from the current platform. This is set in
+    `__post_init__`."""
     def compute_hash(self) -> str:
         """
@@ -2053,8 +2082,8 @@ class DeviceConfig:
                                usedforsecurity=False).hexdigest()
         return hash_str
-    def __init__(self, device: str = "auto") -> None:
-        if device == "auto":
+    def __post_init__(self):
+        if self.device == "auto":
             # Automated device type detection
             from vllm.platforms import current_platform
             self.device_type = current_platform.device_type
@@ -2065,7 +2094,7 @@ class DeviceConfig:
"to turn on verbose logging to help debug the issue.")
else:
# Device type is assigned explicitly
self.device_type = device
self.device_type = self.device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]: