Simplify TokenizerGroup (#16790)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-04-24 12:43:56 +01:00
committed by GitHub
parent 14288d1332
commit 0a05ed57e6
24 changed files with 80 additions and 752 deletions

View File

@@ -52,8 +52,6 @@ if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
ConfigType = type[DataclassInstance]
else:
@@ -1407,83 +1405,33 @@ class CacheConfig:
logger.warning("Possibly too large swap space. %s", msg)
PoolType = Literal["ray"]
@config
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool."""
"""This config is deprecated and will be removed in a future release.
Passing these parameters will have no effect. Please remove them from your
configurations.
"""
pool_size: int = 0
"""Number of tokenizer workers in the pool to use for asynchronous
tokenization. If 0, will use synchronous tokenization."""
pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
"""Type of tokenizer pool to use for asynchronous tokenization. Ignored if
tokenizer_pool_size is 0."""
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
pool_type: str = "ray"
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
extra_config: dict = field(default_factory=dict)
"""Additional config for the pool. The way the config will be used depends
on the pool type. This should be a JSON string that will be parsed into a
dictionary. Ignored if tokenizer_pool_size is 0."""
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")
@classmethod
def create_config(
cls, tokenizer_pool_size: int,
tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
tokenizer_pool_extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type. This can be a JSON string (will be parsed).
"""
if tokenizer_pool_size:
if isinstance(tokenizer_pool_extra_config, str):
tokenizer_pool_extra_config_parsed = json.loads(
tokenizer_pool_extra_config)
else:
tokenizer_pool_extra_config_parsed = (
tokenizer_pool_extra_config or {})
tokenizer_pool_config = cls(tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)
else:
tokenizer_pool_config = None
return tokenizer_pool_config
def __post_init__(self) -> None:
logger.warning_once(
"TokenizerPoolConfig is deprecated and will be removed in a "
"future release. Passing this parameter will have no effect. "
"Please remove it from your configurations.")
class LoadFormat(str, enum.Enum):
@@ -1624,8 +1572,8 @@ class ParallelConfig:
"""Disable the custom all-reduce kernel and fall back to NCCL."""
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
"""Config for the tokenizer pool. If None, will use synchronous
tokenization."""
"""This parameter is deprecated and will be removed in a future release.
Please remove it from your configurations."""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
@@ -2544,7 +2492,6 @@ class SpeculativeConfig:
max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.
disable_custom_all_reduce,
tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
ray_workers_use_nsight=target_parallel_config.
ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,