[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-03-07 00:32:46 +08:00
committed by GitHub
parent 81b2f4a45f
commit 151b08e0fe
7 changed files with 153 additions and 100 deletions

View File

@@ -1366,6 +1366,7 @@ class ParallelConfig:
# will be determined based on the platform.
worker_cls: str = "auto"
sd_worker_cls: str = "auto"
worker_extension_cls: str = ""
# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
@@ -1523,6 +1524,9 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
assert isinstance(self.worker_extension_cls, str), (
"worker_extension_cls must be a string (qualified class name).")
@dataclass
class SchedulerConfig: