[Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
@@ -170,7 +170,8 @@ class GritLMPooler(nn.Module):
|
||||
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
|
||||
1)
|
||||
|
||||
pooled_data = self.head(mean_embeddings)
|
||||
pooled_data = self.head(mean_embeddings,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
pooled_outputs = [
|
||||
PoolingSequenceGroupOutput(data) for data in pooled_data
|
||||
|
||||
Reference in New Issue
Block a user