diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index 5a86940fa..3e029259e 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py). -An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_client.py](../../examples/pooling/plugin/prithvi_geospatial_mae_client.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples. +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_online.py](../../examples/pooling/plugin/prithvi_geospatial_mae_online.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples. 
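Below is a minimal, self-contained sketch of the plugin interface described above. It deliberately uses stand-in types and a hypothetical plugin class name instead of vLLM's real `IOProcessor` base class, `PoolingRequestOutput`, and `IOProcessorResponse`; the terratorch plugin linked above is the authoritative reference for the actual signatures.

```python
# Illustrative sketch only: stand-in types and a hypothetical plugin class are
# used instead of vLLM's real base class and request/response types.
from dataclasses import dataclass
from typing import Any


@dataclass
class FakePoolingRequestOutput:
    """Stand-in for a pooling model's raw output."""

    data: list[float]


class MySegmentationIOProcessor:  # hypothetical plugin class
    def validate_or_generate_params(
        self, params: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Validate any user-supplied pooling/sampling parameters, or generate
        # defaults when none are given; always return the result.
        return params or {"task": "segmentation"}

    def pre_process(self, prompt: Any) -> Any:
        # Convert the user's request input (e.g. a geotiff URL) into model prompts.
        return prompt

    def post_process(self, model_output: FakePoolingRequestOutput) -> dict[str, Any]:
        # Turn the raw pooling output into the plugin's own output type
        # (e.g. a segmentation mask or geotiff image).
        return {"mask": model_output.data}

    def output_to_response(self, plugin_output: dict[str, Any]) -> dict[str, Any]:
        # Online serving only: wrap the plugin output into the response object
        # returned by the API server.
        return {"data": plugin_output}
```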
## Using an IO Processor plugin diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index b4b0150fa..9c5069ba7 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -273,7 +273,7 @@ outputs = llm.embed( print(outputs[0].outputs) ``` -A code example can be found here: [examples/pooling/embed/embed_matryoshka_fy.py](../../examples/pooling/embed/embed_matryoshka_fy.py) +A code example can be found here: [examples/pooling/embed/embed_matryoshka_fy_offline.py](../../examples/pooling/embed/embed_matryoshka_fy_offline.py) ### Online Inference @@ -303,7 +303,7 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy.py) +An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy_client.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy_client.py) ## Deprecated Features diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 089d9649d..0cd21666f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -619,7 +619,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) | `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | !!! note - Named Entity Recognition (NER) usage, please refer to [examples/pooling/token_classify/ner.py](../../examples/pooling/token_classify/ner.py), [examples/pooling/token_classify/ner_client.py](../../examples/pooling/token_classify/ner_client.py). + Named Entity Recognition (NER) usage, please refer to [examples/pooling/token_classify/ner_offline.py](../../examples/pooling/token_classify/ner_offline.py), [examples/pooling/token_classify/ner_online.py](../../examples/pooling/token_classify/ner_online.py). ## List of Multimodal Language Models diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 6585643b5..21f6de962 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -551,7 +551,7 @@ Our Pooling API encodes input prompts using a [pooling model](../models/pooling_ The input format is the same as [Embeddings API](#embeddings-api), but the output data can contain an arbitrary nested list, not just a 1-D list of floats. 
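As a quick illustration, the sketch below issues such a request with `requests` (it assumes a vLLM server is already serving a pooling model on `localhost:8000`; the full client example is linked below).

```python
# Minimal sketch of a Pooling API request; assumes a vLLM server is already
# serving a pooling model at http://localhost:8000.
import pprint

import requests

base_url = "http://localhost:8000"

# Ask the server which model it is serving.
model = requests.get(f"{base_url}/v1/models").json()["data"][0]["id"]

# Same input format as the Embeddings API; the returned data may be an
# arbitrarily nested list rather than a flat list of floats.
response = requests.post(
    f"{base_url}/pooling",
    json={"model": model, "input": "vLLM is great!"},
)
pprint.pprint(response.json())
```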
-Code example: [examples/pooling/pooling/openai_pooling_client.py](../../examples/pooling/pooling/openai_pooling_client.py) +Code example: [examples/pooling/pooling/pooling_online.py](../../examples/pooling/pooling/pooling_online.py) ### Classification API diff --git a/examples/pooling/embed/embed_jina_embeddings_v3.py b/examples/pooling/embed/embed_jina_embeddings_v3_offline.py similarity index 100% rename from examples/pooling/embed/embed_jina_embeddings_v3.py rename to examples/pooling/embed/embed_jina_embeddings_v3_offline.py diff --git a/examples/pooling/embed/embed_matryoshka_fy.py b/examples/pooling/embed/embed_matryoshka_fy_offline.py similarity index 100% rename from examples/pooling/embed/embed_matryoshka_fy.py rename to examples/pooling/embed/embed_matryoshka_fy_offline.py diff --git a/examples/pooling/embed/embedding_requests_base64_client.py b/examples/pooling/embed/embedding_requests_base64_online.py similarity index 69% rename from examples/pooling/embed/embedding_requests_base64_client.py rename to examples/pooling/embed/embedding_requests_base64_online.py index 4c2399b58..88c961370 100644 --- a/examples/pooling/embed/embedding_requests_base64_client.py +++ b/examples/pooling/embed/embedding_requests_base64_online.py @@ -26,36 +26,42 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--model", type=str, default="intfloat/e5-small") - - return parser.parse_args() + parse = argparse.ArgumentParser() + parse.add_argument("--host", type=str, default="localhost") + parse.add_argument("--port", type=int, default=8000) + return parse.parse_args() def main(args): - api_url = f"http://{args.host}:{args.port}/v1/embeddings" - model_name = args.model + base_url = f"http://{args.host}:{args.port}" + models_url = base_url + "/v1/models" + embeddings_url = base_url + "/v1/embeddings" + + response = requests.get(models_url) + model = response.json()["data"][0]["id"] + + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] * 2 # The OpenAI client does not support the embed_dtype and endianness parameters. 
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for endianness in ENDIANNESS: prompt = { - "model": model_name, - "input": "vLLM is great!", + "model": model, + "input": input_texts, "encoding_format": "base64", "embed_dtype": embed_dtype, "endianness": endianness, } - response = post_http_request(prompt=prompt, api_url=api_url) + response = post_http_request(prompt=prompt, api_url=embeddings_url) embedding = [] for data in response.json()["data"]: binary = base64.b64decode(data["embedding"]) tensor = binary2tensor(binary, (-1,), embed_dtype, endianness) embedding.append(tensor.to(torch.float32)) - embedding = torch.cat(embedding) + embedding = torch.stack(embedding) print(embed_dtype, endianness, embedding.shape) diff --git a/examples/pooling/embed/embedding_requests_bytes_client.py b/examples/pooling/embed/embedding_requests_bytes_online.py similarity index 90% rename from examples/pooling/embed/embedding_requests_bytes_client.py rename to examples/pooling/embed/embedding_requests_bytes_online.py index 5ea452524..6a45beb0b 100644 --- a/examples/pooling/embed/embedding_requests_bytes_client.py +++ b/examples/pooling/embed/embedding_requests_bytes_online.py @@ -31,14 +31,18 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--model", type=str, default="intfloat/e5-small") return parser.parse_args() def main(args): - api_url = f"http://{args.host}:{args.port}/v1/embeddings" - model_name = args.model + base_url = f"http://{args.host}:{args.port}" + models_url = base_url + "/v1/models" + embeddings_url = base_url + "/v1/embeddings" + + response = requests.get(models_url) + model = response.json()["data"][0]["id"] + embedding_size = 0 input_texts = [ @@ -50,13 +54,13 @@ def main(args): for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for endianness in ENDIANNESS: prompt = { - "model": model_name, + "model": model, "input": input_texts, "encoding_format": "bytes", "embed_dtype": embed_dtype, "endianness": endianness, } - response = post_http_request(prompt=prompt, api_url=api_url) + response = post_http_request(prompt=prompt, api_url=embeddings_url) metadata = json.loads(response.headers["metadata"]) body = response.content items = [MetadataItem(**x) for x in metadata["data"]] @@ -73,13 +77,13 @@ def main(args): for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for endianness in ENDIANNESS: prompt = { - "model": model_name, + "model": model, "input": input_texts, "encoding_format": "bytes_only", "embed_dtype": embed_dtype, "endianness": endianness, } - response = post_http_request(prompt=prompt, api_url=api_url) + response = post_http_request(prompt=prompt, api_url=embeddings_url) body = response.content items = build_metadata_items( diff --git a/examples/pooling/embed/openai_embedding_matryoshka_fy.py b/examples/pooling/embed/openai_embedding_matryoshka_fy_client.py similarity index 100% rename from examples/pooling/embed/openai_embedding_matryoshka_fy.py rename to examples/pooling/embed/openai_embedding_matryoshka_fy_client.py diff --git a/examples/pooling/plugin/prithvi_geospatial_mae_client.py b/examples/pooling/plugin/prithvi_geospatial_mae_online.py similarity index 100% rename from examples/pooling/plugin/prithvi_geospatial_mae_client.py rename to examples/pooling/plugin/prithvi_geospatial_mae_online.py diff --git a/examples/pooling/pooling/openai_pooling_client.py b/examples/pooling/pooling/pooling_online.py similarity index 81% rename from 
examples/pooling/pooling/openai_pooling_client.py rename to examples/pooling/pooling/pooling_online.py index 569015746..e8ff38889 100644 --- a/examples/pooling/pooling/openai_pooling_client.py +++ b/examples/pooling/pooling/pooling_online.py @@ -25,18 +25,21 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--model", type=str, default="internlm/internlm2-1_8b-reward") return parser.parse_args() def main(args): - api_url = f"http://{args.host}:{args.port}/pooling" - model_name = args.model + base_url = f"http://{args.host}:{args.port}" + models_url = base_url + "/v1/models" + pooling_url = base_url + "/pooling" + + response = requests.get(models_url) + model = response.json()["data"][0]["id"] # Input like Completions API - prompt = {"model": model_name, "input": "vLLM is great!"} - pooling_response = post_http_request(prompt=prompt, api_url=api_url) + prompt = {"model": model, "input": "vLLM is great!"} + pooling_response = post_http_request(prompt=prompt, api_url=pooling_url) print("-" * 50) print("Pooling Response:") pprint.pprint(pooling_response.json()) @@ -44,7 +47,7 @@ def main(args): # Input like Chat API prompt = { - "model": model_name, + "model": model, "messages": [ { "role": "user", @@ -52,7 +55,7 @@ def main(args): } ], } - pooling_response = post_http_request(prompt=prompt, api_url=api_url) + pooling_response = post_http_request(prompt=prompt, api_url=pooling_url) print("Pooling Response:") pprint.pprint(pooling_response.json()) print("-" * 50) diff --git a/examples/pooling/token_classify/ner.py b/examples/pooling/token_classify/ner_offline.py similarity index 100% rename from examples/pooling/token_classify/ner.py rename to examples/pooling/token_classify/ner_offline.py diff --git a/examples/pooling/token_classify/ner_client.py b/examples/pooling/token_classify/ner_online.py similarity index 100% rename from examples/pooling/token_classify/ner_client.py rename to examples/pooling/token_classify/ner_online.py diff --git a/examples/pooling/token_embed/jina_embeddings_v4.py b/examples/pooling/token_embed/jina_embeddings_v4_offline.py similarity index 100% rename from examples/pooling/token_embed/jina_embeddings_v4.py rename to examples/pooling/token_embed/jina_embeddings_v4_offline.py diff --git a/examples/pooling/token_embed/multi_vector_retrieval.py b/examples/pooling/token_embed/multi_vector_retrieval_offline.py similarity index 100% rename from examples/pooling/token_embed/multi_vector_retrieval.py rename to examples/pooling/token_embed/multi_vector_retrieval_offline.py diff --git a/examples/pooling/token_embed/multi_vector_retrieval_client.py b/examples/pooling/token_embed/multi_vector_retrieval_online.py similarity index 100% rename from examples/pooling/token_embed/multi_vector_retrieval_client.py rename to examples/pooling/token_embed/multi_vector_retrieval_online.py diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py index 0bd62b9f4..592c862d0 100644 --- a/tests/entrypoints/pooling/classify/test_online.py +++ b/tests/entrypoints/pooling/classify/test_online.py @@ -167,7 +167,8 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_add_special_tokens(server: RemoteOpenAIServer, model_name: str): - # FIXME: The add_special_tokens parameter doesn't seem to be working.
+ # The add_special_tokens parameter doesn't seem to be working with this + # model (papluca/xlm-roberta-base-language-detection). response = requests.post( server.url_for("classify"), json={"model": model_name, "input": input_text, "add_special_tokens": False}, ) @@ -184,7 +185,110 @@ def test_add_special_tokens(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer): +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chat_request(server: RemoteOpenAIServer, model_name: str): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] + + # test chat request basic usage + response = requests.post( + server.url_for("classify"), + json={"model": model_name, "messages": messages}, + ) + + response.raise_for_status() + output = ClassificationResponse.model_validate(response.json()) + + assert output.object == "list" + assert output.model == MODEL_NAME + assert len(output.data) == 1 + assert hasattr(output.data[0], "label") + assert hasattr(output.data[0], "probs") + assert output.usage.prompt_tokens == 51 + + # test add_generation_prompt + response = requests.post( + server.url_for("classify"), + json={"model": model_name, "messages": messages, "add_generation_prompt": True}, + ) + + response.raise_for_status() + output = ClassificationResponse.model_validate(response.json()) + + assert output.object == "list" + assert output.model == MODEL_NAME + assert len(output.data) == 1 + assert hasattr(output.data[0], "label") + assert hasattr(output.data[0], "probs") + assert output.usage.prompt_tokens == 54 + + # test continue_final_message + response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "messages": messages, + "continue_final_message": True, + }, + ) + + response.raise_for_status() + output = ClassificationResponse.model_validate(response.json()) + + assert output.object == "list" + assert output.model == MODEL_NAME + assert len(output.data) == 1 + assert hasattr(output.data[0], "label") + assert hasattr(output.data[0], "probs") + assert output.usage.prompt_tokens == 49 + + # test add_special_tokens + # The add_special_tokens parameter doesn't seem to be working with this model. + response = requests.post( + server.url_for("classify"), + json={"model": model_name, "messages": messages, "add_special_tokens": True}, + ) + + response.raise_for_status() + output = ClassificationResponse.model_validate(response.json()) + + assert output.object == "list" + assert output.model == MODEL_NAME + assert len(output.data) == 1 + assert hasattr(output.data[0], "label") + assert hasattr(output.data[0], "probs") + assert output.usage.prompt_tokens == 51 + + # test continue_final_message with add_generation_prompt + response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "messages": messages, + "continue_final_message": True, + "add_generation_prompt": True, + }, + ) + assert ( + "Cannot set both `continue_final_message` and `add_generation_prompt` to True."
+ in response.json()["error"]["message"] + ) + + +@pytest.mark.asyncio +async def test_invocations_completion_request(server: RemoteOpenAIServer): request_args = { "model": MODEL_NAME, "input": input_text, @@ -213,6 +317,48 @@ async def test_invocations(server: RemoteOpenAIServer): ) +@pytest.mark.asyncio +async def test_invocations_chat_request(server: RemoteOpenAIServer): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] + + request_args = {"model": MODEL_NAME, "messages": messages} + + classification_response = requests.post( + server.url_for("classify"), json=request_args + ) + classification_response.raise_for_status() + + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) + invocation_response.raise_for_status() + + classification_output = classification_response.json() + invocation_output = invocation_response.json() + + assert classification_output.keys() == invocation_output.keys() + for classification_data, invocation_data in zip( + classification_output["data"], invocation_output["data"] + ): + assert classification_data.keys() == invocation_data.keys() + assert classification_data["probs"] == pytest.approx( + invocation_data["probs"], rel=0.01 + ) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_use_activation(server: RemoteOpenAIServer, model_name: str): diff --git a/tests/entrypoints/pooling/embed/test_online.py b/tests/entrypoints/pooling/embed/test_online.py index 3a368de7d..8f3f8a850 100644 --- a/tests/entrypoints/pooling/embed/test_online.py +++ b/tests/entrypoints/pooling/embed/test_online.py @@ -214,64 +214,6 @@ async def test_completion_request_batched( run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_embedding( - server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str -): - messages = [ - { - "role": "user", - "content": "The cat sat on the mat.", - }, - { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, - { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }, - ] - - chat_response = requests.post( - server.url_for("v1/embeddings"), - json={ - "model": model_name, - "messages": messages, - "encoding_format": "float", - }, - ) - chat_response.raise_for_status() - chat_embeddings = EmbeddingResponse.model_validate(chat_response.json()) - - tokenizer = get_tokenizer(tokenizer_name=model_name) - prompt = tokenizer.apply_chat_template( - messages, - chat_template=DUMMY_CHAT_TEMPLATE, - add_generation_prompt=True, - continue_final_message=False, - tokenize=False, - ) - completion_response = await client.embeddings.create( - model=model_name, - input=prompt, - encoding_format="float", - # To be consistent with chat - extra_body={"add_special_tokens": False}, - ) - completion_embeddings = EmbeddingResponse.model_validate( - completion_response.model_dump(mode="json") - ) - - assert chat_embeddings.id is not None - assert completion_embeddings.id is not None - assert chat_embeddings.created <= completion_embeddings.created - assert chat_embeddings.model_dump(exclude={"id", "created"}) == ( - completion_embeddings.model_dump(exclude={"id", "created"}) - ) - - @pytest.mark.asyncio 
@pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: str): @@ -350,7 +292,129 @@ async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: st @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chat_request( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str +): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] + + # test chat request basic usage + chat_response = requests.post( + server.url_for("v1/embeddings"), + json={ + "model": model_name, + "messages": messages, + "encoding_format": "float", + }, + ) + chat_response.raise_for_status() + chat_embeddings = EmbeddingResponse.model_validate(chat_response.json()) + + tokenizer = get_tokenizer(tokenizer_name=model_name) + prompt = tokenizer.apply_chat_template( + messages, + chat_template=DUMMY_CHAT_TEMPLATE, + add_generation_prompt=True, + continue_final_message=False, + tokenize=False, + ) + completion_response = await client.embeddings.create( + model=model_name, + input=prompt, + encoding_format="float", + # To be consistent with chat + extra_body={"add_special_tokens": False}, + ) + completion_embeddings = EmbeddingResponse.model_validate( + completion_response.model_dump(mode="json") + ) + + assert chat_embeddings.id is not None + assert completion_embeddings.id is not None + assert chat_embeddings.created <= completion_embeddings.created + assert chat_embeddings.model_dump(exclude={"id", "created"}) == ( + completion_embeddings.model_dump(exclude={"id", "created"}) + ) + + # test add_generation_prompt + response = requests.post( + server.url_for("v1/embeddings"), + json={"model": model_name, "messages": messages, "add_generation_prompt": True}, + ) + + response.raise_for_status() + output = EmbeddingResponse.model_validate(response.json()) + + assert output.object == "list" + assert len(output.data) == 1 + assert output.model == MODEL_NAME + assert output.usage.prompt_tokens == 34 + + # test continue_final_message + response = requests.post( + server.url_for("v1/embeddings"), + json={ + "model": model_name, + "messages": messages, + "continue_final_message": True, + }, + ) + + response.raise_for_status() + output = EmbeddingResponse.model_validate(response.json()) + + assert output.object == "list" + assert len(output.data) == 1 + assert output.model == MODEL_NAME + assert output.usage.prompt_tokens == 33 + + # test add_special_tokens + response = requests.post( + server.url_for("v1/embeddings"), + json={"model": model_name, "messages": messages, "add_special_tokens": True}, + ) + + response.raise_for_status() + output = EmbeddingResponse.model_validate(response.json()) + + assert output.object == "list" + assert len(output.data) == 1 + assert output.model == MODEL_NAME + assert output.usage.prompt_tokens == 36 + + # test continue_final_message with add_generation_prompt + response = requests.post( + server.url_for("v1/embeddings"), + json={ + "model": model_name, + "messages": messages, + "continue_final_message": True, + "add_generation_prompt": True, + }, + ) + assert ( + "Cannot set both `continue_final_message` and `add_generation_prompt` to True." 
+ in response.json()["error"]["message"] + ) + + +@pytest.mark.asyncio +async def test_invocations_completion_request( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI +): request_args = { "model": MODEL_NAME, "input": input_text, @@ -381,7 +445,7 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA @pytest.mark.asyncio -async def test_invocations_conversation(server: RemoteOpenAIServer): +async def test_invocations_chat_request(server: RemoteOpenAIServer): messages = [ { "role": "user", diff --git a/tests/entrypoints/pooling/pooling/test_online.py b/tests/entrypoints/pooling/pooling/test_online.py index ce4014ba1..0ca841b4a 100644 --- a/tests/entrypoints/pooling/pooling/test_online.py +++ b/tests/entrypoints/pooling/pooling/test_online.py @@ -138,7 +138,7 @@ def test_completion_request_batched(server: RemoteOpenAIServer, model_name: str) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str): +async def test_chat_request(server: RemoteOpenAIServer, model_name: str): messages = [ { "role": "user", @@ -154,6 +154,7 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str) }, ] + # test chat request basic usage chat_response = requests.post( server.url_for("pooling"), json={ @@ -193,6 +194,68 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str) completion_poolings.model_dump(exclude={"id", "created"}) ) + # test add_generation_prompt + response = requests.post( + server.url_for("pooling"), + json={"model": model_name, "messages": messages, "add_generation_prompt": True}, + ) + + response.raise_for_status() + output = PoolingResponse.model_validate(response.json()) + + assert output.object == "list" + assert len(output.data) == 1 + assert output.model == MODEL_NAME + assert output.usage.prompt_tokens == 33 + + # test continue_final_message + # The continue_final_message parameter doesn't seem to be working with this model. + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "messages": messages, + "continue_final_message": True, + }, + ) + + response.raise_for_status() + output = PoolingResponse.model_validate(response.json()) + + assert output.object == "list" + assert len(output.data) == 1 + assert output.model == MODEL_NAME + assert output.usage.prompt_tokens == 33 + + # test add_special_tokens + response = requests.post( + server.url_for("pooling"), + json={"model": model_name, "messages": messages, "add_special_tokens": True}, + ) + + response.raise_for_status() + output = PoolingResponse.model_validate(response.json()) + + assert output.object == "list" + assert len(output.data) == 1 + assert output.model == MODEL_NAME + assert output.usage.prompt_tokens == 34 + + # test continue_final_message with add_generation_prompt + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "messages": messages, + "continue_final_message": True, + "add_generation_prompt": True, + }, + ) + assert ( + "Cannot set both `continue_final_message` and `add_generation_prompt` to True." 
+ in response.json()["error"]["message"] + ) + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @@ -430,7 +493,7 @@ async def test_params_not_supported( @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer): +async def test_invocations_completion_request(server: RemoteOpenAIServer): request_args = { "model": MODEL_NAME, "input": input_text, @@ -462,7 +525,7 @@ @pytest.mark.asyncio -async def test_invocations_conversation(server: RemoteOpenAIServer): +async def test_invocations_chat_request(server: RemoteOpenAIServer): messages = [ { "role": "user", diff --git a/vllm/entrypoints/pooling/base/protocol.py b/vllm/entrypoints/pooling/base/protocol.py index 0a60be888..1a079306c 100644 --- a/vllm/entrypoints/pooling/base/protocol.py +++ b/vllm/entrypoints/pooling/base/protocol.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Annotated +from typing import Annotated, Any -from pydantic import Field +from pydantic import Field, model_validator +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel from vllm.utils import random_uuid @@ -44,3 +45,66 @@ class CompletionRequestMixin(OpenAIBaseModel): "the prompt." ), ) + + +class ChatRequestMixin(OpenAIBaseModel): + messages: list[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field( + default=False, + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), + ) + + continue_final_message: bool = Field( + default=False, + description=( + "If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + 'This allows you to "prefill" part of the model\'s response for it. ' + "Cannot be used at the same time as `add_generation_prompt`." + ), + ) + + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)." + ), + ) + + chat_template: str | None = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one." + ), + ) + + chat_template_kwargs: dict[str, Any] | None = Field( + default=None, + description=( + "Additional keyword args to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ) + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True."
+ ) + return data diff --git a/vllm/entrypoints/pooling/classify/protocol.py b/vllm/entrypoints/pooling/classify/protocol.py index cd239cff7..d90665bf8 100644 --- a/vllm/entrypoints/pooling/classify/protocol.py +++ b/vllm/entrypoints/pooling/classify/protocol.py @@ -10,9 +10,9 @@ from pydantic import ( from vllm import PoolingParams from vllm.config.pooler import get_use_activation -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.pooling.base.protocol import ( + ChatRequestMixin, CompletionRequestMixin, PoolingBasicRequestMixin, ) @@ -45,48 +45,8 @@ class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionReques ) -class ClassificationChatRequest(PoolingBasicRequestMixin): - messages: list[ChatCompletionMessageParam] - +class ClassificationChatRequest(PoolingBasicRequestMixin, ChatRequestMixin): # --8<-- [start:chat-classification-extra-params] - add_generation_prompt: bool = Field( - default=False, - description=( - "If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model." - ), - ) - - add_special_tokens: bool = Field( - default=False, - description=( - "If true, special tokens (e.g. BOS) will be added to the prompt " - "on top of what is added by the chat template. " - "For most models, the chat template takes care of adding the " - "special tokens so this should be set to false (as is the " - "default)." - ), - ) - - chat_template: str | None = Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), - ) - - chat_template_kwargs: dict[str, Any] | None = Field( - default=None, - description=( - "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template." 
- ), - ) - mm_processor_kwargs: dict[str, Any] | None = Field( default=None, description=("Additional kwargs to pass to the HF processor."), diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index 7da2210c6..2ff313930 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -86,8 +86,8 @@ class ClassificationMixin(OpenAIServing): ChatTemplateContentFormatOption, getattr(self, "chat_template_content_format", "auto"), ), - add_generation_prompt=False, - continue_final_message=False, + add_generation_prompt=chat_request.add_generation_prompt, + continue_final_message=chat_request.continue_final_message, add_special_tokens=chat_request.add_special_tokens, ) ctx.engine_prompts = engine_prompts diff --git a/vllm/entrypoints/pooling/embed/protocol.py b/vllm/entrypoints/pooling/embed/protocol.py index db3d74052..ece014f4a 100644 --- a/vllm/entrypoints/pooling/embed/protocol.py +++ b/vllm/entrypoints/pooling/embed/protocol.py @@ -5,13 +5,12 @@ from typing import Any, TypeAlias from pydantic import ( Field, - model_validator, ) from vllm import PoolingParams -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.pooling.base.protocol import ( + ChatRequestMixin, CompletionRequestMixin, PoolingBasicRequestMixin, ) @@ -57,57 +56,11 @@ class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixi ) -class EmbeddingChatRequest(PoolingBasicRequestMixin): - messages: list[ChatCompletionMessageParam] - +class EmbeddingChatRequest(PoolingBasicRequestMixin, ChatRequestMixin): encoding_format: EncodingFormat = "float" dimensions: int | None = None # --8<-- [start:chat-embedding-extra-params] - add_generation_prompt: bool = Field( - default=False, - description=( - "If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model." - ), - ) - continue_final_message: bool = Field( - default=False, - description=( - "If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - 'This allows you to "prefill" part of the model\'s response for it. ' - "Cannot be used at the same time as `add_generation_prompt`." - ), - ) - add_special_tokens: bool = Field( - default=False, - description=( - "If true, special tokens (e.g. BOS) will be added to the prompt " - "on top of what is added by the chat template. " - "For most models, the chat template takes care of adding the " - "special tokens so this should be set to false (as is the " - "default)." - ), - ) - chat_template: str | None = Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), - ) - chat_template_kwargs: dict[str, Any] | None = Field( - default=None, - description=( - "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template." 
- ), - ) mm_processor_kwargs: dict[str, Any] | None = Field( default=None, description=("Additional kwargs to pass to the HF processor."), @@ -134,16 +87,6 @@ class EmbeddingChatRequest(PoolingBasicRequestMixin): ) # --8<-- [end:chat-embedding-extra-params] - @model_validator(mode="before") - @classmethod - def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get("add_generation_prompt"): - raise ValueError( - "Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True." - ) - return data - def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 1a2bfd770..b53caf81b 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -144,10 +144,8 @@ class OpenAIServingPooling(OpenAIServing): request.messages, chat_template=request.chat_template or self.chat_template, chat_template_content_format=self.chat_template_content_format, - # In pooling requests, we are not generating tokens, - # so there is no need to append extra tokens to the input - add_generation_prompt=False, - continue_final_message=False, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, add_special_tokens=request.add_special_tokens, ) elif isinstance(request, PoolingCompletionRequest):
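For reference, the sketch below shows how the chat-style parameters provided by `ChatRequestMixin` surface to API clients (it assumes a running vLLM server with a classification model on `localhost:8000`; the request shape mirrors the tests above).

```python
# Minimal sketch of a chat-style /classify request using the ChatRequestMixin
# parameters; assumes a vLLM server with a classification model at
# http://localhost:8000.
import requests

base_url = "http://localhost:8000"
model = requests.get(f"{base_url}/v1/models").json()["data"][0]["id"]

messages = [
    {"role": "user", "content": "The cat sat on the mat."},
    {"role": "assistant", "content": "A feline was resting on a rug."},
    {"role": "user", "content": "Stars twinkle brightly in the night sky."},
]

# add_generation_prompt and continue_final_message are mutually exclusive;
# setting both triggers the validator defined on ChatRequestMixin.
response = requests.post(
    f"{base_url}/classify",
    json={"model": model, "messages": messages, "add_generation_prompt": True},
)
print(response.json())
```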