[Bugfix] Handle process_weights_after_loading for QKVCrossParallelLinear (#15328)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-04-09 01:02:23 +08:00
committed by GitHub
parent 4ebc0b9640
commit 40b4284fe3
3 changed files with 33 additions and 6 deletions

View File

@@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
@@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
@@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)