[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-03-07 00:32:46 +08:00
committed by GitHub
parent 81b2f4a45f
commit 151b08e0fe
7 changed files with 153 additions and 100 deletions

View File

@@ -202,6 +202,7 @@ class EngineArgs:
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto"
worker_extension_cls: str = ""
kv_transfer_config: Optional[KVTransferConfig] = None
@@ -1015,6 +1016,13 @@ class EngineArgs:
type=str,
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
'--worker-extension-cls',
type=str,
default="",
help='The worker extension class on top of the worker cls, '
'it is useful if you just want to add new functions to the worker '
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=nullable_str,
@@ -1209,6 +1217,7 @@ class EngineArgs:
ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
)
max_model_len = model_config.max_model_len