Fix model name included in responses (#24663)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-09-11 18:47:51 +01:00
committed by GitHub
parent 4aa23892d6
commit c1eda615ba
10 changed files with 50 additions and 74 deletions

View File

@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
]
@dataclass
@@ -270,6 +274,42 @@ def test_async_serving_chat_init():
assert serving_completion.chat_template == CHAT_TEMPLATE
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
messages = [{"role": "user", "content": "what is 1+1?"}]
async def return_model_name(*args):
return args[3]
serving_chat.chat_completion_full_generator = return_model_name
# Test that full name is returned when short name is requested
req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
assert await serving_chat.create_chat_completion(req) == MODEL_NAME
# Test that full name is returned when empty string is specified
req = ChatCompletionRequest(model="", messages=messages)
assert await serving_chat.create_chat_completion(req) == MODEL_NAME
# Test that full name is returned when no model is specified
req = ChatCompletionRequest(messages=messages)
assert await serving_chat.create_chat_completion(req) == MODEL_NAME
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=MQLLMEngineClient)