[Bugfix] Fix broken deepseek fp8 TP weights loading (#24367)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -262,7 +262,7 @@ class LinearBase(CustomOp):
|
||||
self.tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
|
||||
-def __post_init__(self):
|
||||
+def update_param_tp_status(self):
|
||||
for param in self.parameters():
|
||||
if isinstance(param, BasevLLMParameter):
|
||||
param.tp_rank = self.tp_rank
|
||||
@@ -459,6 +459,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
+self.update_param_tp_status()
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
|
||||
@@ -1250,6 +1251,7 @@ class RowParallelLinear(LinearBase):
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
+self.update_param_tp_status()
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
|
||||
Reference in New Issue
Block a user