[Model] Systematic support for fp32 head, pooling models part (#23810)

Signed-off-by: wang.yuqi <noooop@126.com>

Author: wang.yuqi
Date:   2025-09-09 22:29:50 +08:00 (committed by GitHub)
Parent: a55cf41a09
Commit: 19332c0479

14 changed files with 166 additions and 61 deletions
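
Summary of the change: the diff below stops hardcoding torch.float32 for the Sentence-Transformers projector and instead uses the new model_config.head_dtype, so a pooling/embedding head can stay in fp32 for output quality while the model body runs in fp16/bf16. A minimal sketch of the dtype-selection idea, assuming a resolver along these lines (resolve_head_dtype and its fp32 default are illustrative, not code from this commit):

```python
import torch


def resolve_head_dtype(head_dtype: "torch.dtype | None") -> torch.dtype:
    # Illustrative policy: the pooling head defaults to fp32 for output
    # quality, even when the model body runs in fp16/bf16, unless an
    # explicit override is configured.
    return head_dtype if head_dtype is not None else torch.float32


assert resolve_head_dtype(None) is torch.float32
assert resolve_head_dtype(torch.bfloat16) is torch.bfloat16
```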

@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
             linear = nn.Linear(layer_config.get("in_features", 768),
                                layer_config.get("out_features", 768),
                                bias=layer_config.get("bias", True),
-                               dtype=torch.float32)
+                               dtype=model_config.head_dtype)
             if not _load_dense_weights(linear, folder, model_config):
                 continue
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
             layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))
-        return nn.Sequential(*layers).to(dtype=torch.float32)
+        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
     except Exception:
         logger.exception("ST projector loading failed")
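
Both hunks above apply the same pattern: build each projector layer directly in model_config.head_dtype rather than constructing it in fp32 and casting afterwards. A self-contained sketch of that pattern (HeadConfig and build_projector are stand-ins for this file's ModelConfig plumbing, not vLLM's real classes):

```python
import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass
class HeadConfig:
    # Stand-in for ModelConfig.head_dtype; the fp32 default is an assumption.
    head_dtype: torch.dtype = torch.float32


def build_projector(layer_configs: list[dict], cfg: HeadConfig) -> nn.Sequential:
    layers: list[nn.Module] = []
    for lc in layer_configs:
        # Create each layer in the head dtype up front, as the hunks above do.
        layers.append(nn.Linear(lc.get("in_features", 768),
                                lc.get("out_features", 768),
                                bias=lc.get("bias", True),
                                dtype=cfg.head_dtype))
    # The trailing .to() now only confirms the dtype; nothing is recast.
    return nn.Sequential(*layers).to(dtype=cfg.head_dtype)


proj = build_projector([{"in_features": 768, "out_features": 256}],
                       HeadConfig(head_dtype=torch.float32))
assert proj[0].weight.dtype is torch.float32
```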
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
         if weight_key in state_dict:
             weight_loader = getattr(linear.weight, "weight_loader",
                                     default_weight_loader)
-            weight_loader(linear.weight,
-                          state_dict[weight_key].to(torch.float32))
+            weight_loader(linear.weight, state_dict[weight_key])
             bias_key = weight_key.replace("weight", "bias")
             if linear.bias is not None and bias_key in state_dict:
                 bias_loader = getattr(linear.bias, "weight_loader",
                                       default_weight_loader)
-                bias_loader(linear.bias,
-                            state_dict[bias_key].to(torch.float32))
+                bias_loader(linear.bias, state_dict[bias_key])
         return True
     except Exception:
         logger.exception("Failed to load %s", filename)
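
The removed .to(torch.float32) calls are redundant rather than behavior-changing: the destination parameters are already created in head_dtype, and a copy-based loader converts the checkpoint tensor to the destination dtype on write (assuming default_weight_loader is copy-based, which is how the removed casts become unnecessary). A small demonstration of that copy_ semantics, independent of vLLM:

```python
import torch
import torch.nn as nn


def copy_based_loader(param: nn.Parameter, loaded: torch.Tensor) -> None:
    # Tensor.copy_ converts the source tensor to the destination's dtype,
    # so no explicit .to(torch.float32) is needed on the checkpoint side.
    assert param.shape == loaded.shape
    param.data.copy_(loaded)


linear = nn.Linear(4, 2, dtype=torch.float32)   # head built in head_dtype
ckpt = torch.randn(2, 4, dtype=torch.float16)   # half-precision checkpoint
copy_based_loader(linear.weight, ckpt)
assert linear.weight.dtype is torch.float32     # cast happened inside copy_
```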