[Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631)

Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
Pooya Davoodi
2025-04-14 23:09:58 -07:00
committed by GitHub
parent dbb036cf61
commit bc5dd4f669
2 changed files with 13 additions and 11 deletions

View File

@@ -170,7 +170,8 @@ class GritLMPooler(nn.Module):
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1)
pooled_data = self.head(mean_embeddings)
pooled_data = self.head(mean_embeddings,
pooling_metadata=pooling_metadata)
pooled_outputs = [
PoolingSequenceGroupOutput(data) for data in pooled_data