Turn @config into a dataclass_transform (#31541)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,12 +3,10 @@
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -50,7 +48,6 @@ All2AllBackend = Literal[
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class EPLBConfig:
|
||||
"""Configuration for Expert Parallel Load Balancing (EP)."""
|
||||
|
||||
@@ -94,7 +91,6 @@ class EPLBConfig:
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ParallelConfig:
|
||||
"""Configuration for the distributed execution."""
|
||||
|
||||
@@ -715,6 +711,3 @@ class ParallelConfig:
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def replace(self, **kwargs) -> Self:
|
||||
return replace(self, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user