[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -357,7 +357,7 @@ class BertOutput(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
is_pooling_model = True
|
||||
|
||||
@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
|
||||
return loaded_params
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
|
||||
return torch.stack(pooled_list, dim=0).contiguous()
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
"""
|
||||
BertEmbeddingModel + SPLADE sparse embedding.
|
||||
@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
return loaded
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
@default_pooling_type("ALL")
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
class BertForTokenClassification(nn.Module):
|
||||
is_pooling_model = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user