[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
"""A model that uses Roberta to provide embedding functionalities."""
|
||||
|
||||
@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
return loader.load_weights(weights_list, mapper=mapper)
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user