[Model] Automatic conversion of TokenClassification model (#30666)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
tokens = getattr(text_config, "classifier_from_token", None)
|
||||
method = getattr(text_config, "method", None)
|
||||
|
||||
def auto_set_score_bias(weights):
|
||||
for name, weight in weights:
|
||||
if name == "score.bias":
|
||||
device = self.score.weight.device
|
||||
dtype = self.score.weight.dtype
|
||||
bias = weight.to(device).to(dtype)
|
||||
self.score.bias = torch.nn.Parameter(bias)
|
||||
self.score.skip_bias_add = False
|
||||
else:
|
||||
yield name, weight
|
||||
|
||||
weights = auto_set_score_bias(weights)
|
||||
if tokens is None and method is None:
|
||||
return super().load_weights(weights)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user