[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
|
||||
linear = nn.Linear(layer_config.get("in_features", 768),
|
||||
layer_config.get("out_features", 768),
|
||||
bias=layer_config.get("bias", True),
|
||||
- dtype=torch.float32)
|
||||
+ dtype=model_config.head_dtype)
|
||||
|
||||
if not _load_dense_weights(linear, folder, model_config):
|
||||
continue
|
||||
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
|
||||
layers.append(linear)
|
||||
if act_name := layer_config.get("activation_function"):
|
||||
layers.append(get_act_fn(act_name))
|
||||
- return nn.Sequential(*layers).to(dtype=torch.float32)
|
||||
+ return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
|
||||
except Exception:
|
||||
logger.exception("ST projector loading failed")
|
||||
|
||||
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
|
||||
if weight_key in state_dict:
|
||||
weight_loader = getattr(linear.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
- weight_loader(linear.weight,
|
||||
- state_dict[weight_key].to(torch.float32))
|
||||
+ weight_loader(linear.weight, state_dict[weight_key])
|
||||
|
||||
bias_key = weight_key.replace("weight", "bias")
|
||||
if linear.bias is not None and bias_key in state_dict:
|
||||
bias_loader = getattr(linear.bias, "weight_loader",
|
||||
default_weight_loader)
|
||||
- bias_loader(linear.bias,
|
||||
- state_dict[bias_key].to(torch.float32))
|
||||
+ bias_loader(linear.bias, state_dict[bias_key])
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to load %s", filename)
|
||||
|
||||
Reference in New Issue
Block a user