Adds truncate_prompt_tokens param for embeddings creation (#8999)

Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Flávia Béo
2024-10-04 15:31:40 -03:00
committed by GitHub
parent 26aa325f4f
commit 0dcc8cbe5a
3 changed files with 76 additions and 5 deletions

View File

@@ -144,3 +144,64 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
0].embedding
assert responses_float.data[1].embedding == responses_default.data[
1].embedding
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation(
embedding_client: openai.AsyncOpenAI, model_name: str):
input_texts = [
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
# test single embedding
embeddings = await embedding_client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 10})
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 10
assert embeddings.usage.total_tokens == 10
input_tokens = [
1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
]
embeddings = await embedding_client.embeddings.create(
model=model_name,
input=input_tokens,
extra_body={"truncate_prompt_tokens": 10})
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 10
assert embeddings.usage.total_tokens == 10
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation_invalid(
embedding_client: openai.AsyncOpenAI, model_name: str):
input_texts = [
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
with pytest.raises(openai.BadRequestError):
embeddings = await embedding_client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 8193})
assert "error" in embeddings.object
assert "truncate_prompt_tokens value is greater than max_model_len. "\
"Please, select a smaller truncation size." in embeddings.message