[Model] Automatic conversion of TokenClassification model (#30666)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2025-12-15 16:13:00 +08:00
committed by GitHub
parent 33278073d6
commit 4429d934de
4 changed files with 45 additions and 0 deletions

View File

@@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T:
tokens = getattr(text_config, "classifier_from_token", None)
method = getattr(text_config, "method", None)
def auto_set_score_bias(weights):
for name, weight in weights:
if name == "score.bias":
device = self.score.weight.device
dtype = self.score.weight.dtype
bias = weight.to(device).to(dtype)
self.score.bias = torch.nn.Parameter(bias)
self.score.skip_bias_add = False
else:
yield name, weight
weights = auto_set_score_bias(weights)
if tokens is None and method is None:
return super().load_weights(weights)
else: