[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1366,6 +1366,7 @@ class ParallelConfig:
|
||||
# will be determined based on the platform.
|
||||
worker_cls: str = "auto"
|
||||
sd_worker_cls: str = "auto"
|
||||
worker_extension_cls: str = ""
|
||||
|
||||
# world_size is TPxPP, it affects the number of workers we create.
|
||||
world_size: int = field(init=False)
|
||||
@@ -1523,6 +1524,9 @@ class ParallelConfig:
|
||||
raise ValueError("Unable to use nsight profiling unless workers "
|
||||
"run with Ray.")
|
||||
|
||||
assert isinstance(self.worker_extension_cls, str), (
|
||||
"worker_extension_cls must be a string (qualified class name).")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerConfig:
|
||||
|
||||
Reference in New Issue
Block a user