(bugfix): Fixed encode in LLM entrypoint for IOProcessor plugin prompts (#34618)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
This commit is contained in:
@@ -20,13 +20,15 @@ def main():
|
||||
torch.set_default_dtype(torch.float16)
|
||||
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
|
||||
|
||||
img_prompt = dict(
|
||||
img_data = dict(
|
||||
data=image_url,
|
||||
data_format="url",
|
||||
image_format="tiff",
|
||||
out_data_format="b64_json",
|
||||
)
|
||||
|
||||
prompt = dict(data=img_data)
|
||||
|
||||
llm = LLM(
|
||||
model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
|
||||
skip_tokenizer_init=True,
|
||||
@@ -41,7 +43,7 @@ def main():
|
||||
enable_mm_embeds=True,
|
||||
)
|
||||
|
||||
pooler_output = llm.encode(img_prompt, pooling_task="plugin")
|
||||
pooler_output = llm.encode(prompt, pooling_task="plugin")
|
||||
output = pooler_output[0].outputs
|
||||
|
||||
print(output)
|
||||
|
||||
@@ -120,13 +120,15 @@ async def test_prithvi_mae_plugin_online(
|
||||
def test_prithvi_mae_plugin_offline(
|
||||
vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
|
||||
):
|
||||
img_prompt = dict(
|
||||
img_data = dict(
|
||||
data=image_url,
|
||||
data_format="url",
|
||||
image_format="tiff",
|
||||
out_data_format="b64_json",
|
||||
)
|
||||
|
||||
prompt = dict(data=img_data)
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
@@ -139,7 +141,7 @@ def test_prithvi_mae_plugin_offline(
|
||||
io_processor_plugin=plugin,
|
||||
default_torch_num_threads=1,
|
||||
) as llm_runner:
|
||||
pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
|
||||
pooler_output = llm_runner.get_llm().encode(prompt, pooling_task="plugin")
|
||||
output = pooler_output[0].outputs
|
||||
|
||||
# verify the output is formatted as expected for this plugin
|
||||
|
||||
@@ -1135,7 +1135,15 @@ class LLM:
|
||||
)
|
||||
|
||||
# Validate the request data is valid for the loaded plugin
|
||||
validated_prompt = self.io_processor.parse_data(prompts)
|
||||
prompt_data = prompts.get("data")
|
||||
if prompt_data is None:
|
||||
raise ValueError(
|
||||
"The 'data' field of the prompt is expected to contain "
|
||||
"the prompt data and it cannot be None. "
|
||||
"Refer to the documentation of the IOProcessor "
|
||||
"in use for more details."
|
||||
)
|
||||
validated_prompt = self.io_processor.parse_data(prompt_data)
|
||||
|
||||
# obtain the actual model prompts from the pre-processor
|
||||
prompts = self.io_processor.pre_process(prompt=validated_prompt)
|
||||
|
||||
Reference in New Issue
Block a user