[Bugfix] Fix score api for missing max_model_len validation (#12119)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
@@ -12,6 +12,9 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
||||
def server():
|
||||
args = [
|
||||
"--enforce-eager",
|
||||
# Will be used in tests to compare prompt input length
|
||||
"--max-model-len",
|
||||
"100"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@@ -20,8 +23,7 @@ def server():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
@@ -45,8 +47,7 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = [
|
||||
"What is the capital of the United States?",
|
||||
"What is the capital of France?"
|
||||
@@ -73,8 +74,7 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
|
||||
@@ -91,3 +91,36 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 1
|
||||
assert score.data[0].score >= 0.9
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
    """Verify the score endpoint validates prompts against --max-model-len.

    The server fixture is started with ``--max-model-len 100``, so the
    oversized ``text_1`` below must be rejected with HTTP 400 — both when
    no truncation is requested and when ``truncate_prompt_tokens`` exceeds
    the model's maximum length.
    """
    # NOTE: this is a synchronous test (plain `requests`), so it must NOT
    # carry @pytest.mark.asyncio — that mark only applies to `async def`.
    # 20x repetition guarantees the prompt exceeds the 100-token limit.
    text_1 = "What is the capital of France?" * 20
    text_2 = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    # Case 1: prompt longer than max_model_len, no truncation requested.
    score_response = requests.post(server.url_for("score"),
                                   json={
                                       "model": model_name,
                                       "text_1": text_1,
                                       "text_2": text_2,
                                   })
    assert score_response.status_code == 400
    # Assert just a small fragment of the response
    assert "Please reduce the length of the input." in \
        score_response.text

    # Case 2: truncation requested, but the truncation size (101) is itself
    # larger than max_model_len (100) — must also be rejected.
    score_response = requests.post(server.url_for("score"),
                                   json={
                                       "model": model_name,
                                       "text_1": text_1,
                                       "text_2": text_2,
                                       "truncate_prompt_tokens": 101
                                   })
    assert score_response.status_code == 400
    assert "Please, select a smaller truncation size." in \
        score_response.text
|
||||
|
||||
Reference in New Issue
Block a user