[Bugfix] Handle process_weights_after_loading for QKVCrossParallelLinear (#15328)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
else:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
@@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
@@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
else:
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
||||
Reference in New Issue
Block a user