Improve configs - TokenizerPoolConfig + DeviceConfig (#16603)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -182,6 +182,23 @@ def config(cls: type[Config]) -> type[Config]:
|
||||
return cls
|
||||
|
||||
|
||||
def get_field(cls: type[Config], name: str) -> Field:
|
||||
"""Get the default factory field of a dataclass by name. Used for getting
|
||||
default factory fields in `EngineArgs`."""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError("The given class is not a dataclass.")
|
||||
cls_fields = {f.name: f for f in fields(cls)}
|
||||
if name not in cls_fields:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|
||||
named_field: Field = cls_fields.get(name)
|
||||
if (default_factory := named_field.default_factory) is not MISSING:
|
||||
return field(default_factory=default_factory)
|
||||
if (default := named_field.default) is not MISSING:
|
||||
return field(default=default)
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.{name} must have a default value or default factory.")
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""Configuration for the model.
|
||||
|
||||
@@ -1364,20 +1381,26 @@ class CacheConfig:
|
||||
logger.warning("Possibly too large swap space. %s", msg)
|
||||
|
||||
|
||||
PoolType = Literal["ray"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class TokenizerPoolConfig:
|
||||
"""Configuration for the tokenizer pool.
|
||||
"""Configuration for the tokenizer pool."""
|
||||
|
||||
Args:
|
||||
pool_size: Number of tokenizer workers in the pool.
|
||||
pool_type: Type of the pool.
|
||||
extra_config: Additional config for the pool.
|
||||
The way the config will be used depends on the
|
||||
pool type.
|
||||
"""
|
||||
pool_size: int
|
||||
pool_type: Union[str, type["BaseTokenizerGroup"]]
|
||||
extra_config: dict
|
||||
pool_size: int = 0
|
||||
"""Number of tokenizer workers in the pool to use for asynchronous
|
||||
tokenization. If 0, will use synchronous tokenization."""
|
||||
|
||||
pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
|
||||
"""Type of tokenizer pool to use for asynchronous tokenization. Ignored if
|
||||
tokenizer_pool_size is 0."""
|
||||
|
||||
extra_config: dict = field(default_factory=dict)
|
||||
"""Additional config for the pool. The way the config will be used depends
|
||||
on the pool type. This should be a JSON string that will be parsed into a
|
||||
dictionary. Ignored if tokenizer_pool_size is 0."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@@ -1408,7 +1431,7 @@ class TokenizerPoolConfig:
|
||||
@classmethod
|
||||
def create_config(
|
||||
cls, tokenizer_pool_size: int,
|
||||
tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
|
||||
tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
|
||||
tokenizer_pool_extra_config: Optional[Union[str, dict]]
|
||||
) -> Optional["TokenizerPoolConfig"]:
|
||||
"""Create a TokenizerPoolConfig from the given parameters.
|
||||
@@ -1483,7 +1506,7 @@ class LoadConfig:
|
||||
download_dir: Optional[str] = None
|
||||
"""Directory to download and load the weights, default to the default
|
||||
cache directory of Hugging Face."""
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = None
|
||||
model_loader_extra_config: dict = field(default_factory=dict)
|
||||
"""Extra config for model loader. This will be passed to the model loader
|
||||
corresponding to the chosen load_format. This should be a JSON string that
|
||||
will be parsed into a dictionary."""
|
||||
@@ -1514,10 +1537,6 @@ class LoadConfig:
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
if isinstance(model_loader_extra_config, str):
|
||||
self.model_loader_extra_config = json.loads(
|
||||
model_loader_extra_config)
|
||||
if isinstance(self.load_format, str):
|
||||
load_format = self.load_format.lower()
|
||||
self.load_format = LoadFormat(load_format)
|
||||
@@ -2029,9 +2048,19 @@ class SchedulerConfig:
|
||||
return self.num_scheduler_steps > 1
|
||||
|
||||
|
||||
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DeviceConfig:
|
||||
device: Optional[torch.device]
|
||||
device_type: str
|
||||
"""Configuration for the device to use for vLLM execution."""
|
||||
|
||||
device: Union[Device, torch.device] = "auto"
|
||||
"""Device type for vLLM execution."""
|
||||
device_type: str = field(init=False)
|
||||
"""Device type from the current platform. This is set in
|
||||
`__post_init__`."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@@ -2053,8 +2082,8 @@ class DeviceConfig:
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __init__(self, device: str = "auto") -> None:
|
||||
if device == "auto":
|
||||
def __post_init__(self):
|
||||
if self.device == "auto":
|
||||
# Automated device type detection
|
||||
from vllm.platforms import current_platform
|
||||
self.device_type = current_platform.device_type
|
||||
@@ -2065,7 +2094,7 @@ class DeviceConfig:
|
||||
"to turn on verbose logging to help debug the issue.")
|
||||
else:
|
||||
# Device type is assigned explicitly
|
||||
self.device_type = device
|
||||
self.device_type = self.device
|
||||
|
||||
# Some device types require processing inputs on CPU
|
||||
if self.device_type in ["neuron"]:
|
||||
|
||||
Reference in New Issue
Block a user