[Core] [Bugfix] Add Input Embeddings (#15428)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Andrew Sansom
2025-05-02 03:06:39 -05:00
committed by GitHub
parent 9e2de9b9e9
commit cc2a77d7f1
22 changed files with 691 additions and 113 deletions

View File

@@ -1,4 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
import pytest
import torch
@@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
"VLLM_USE_V1") == "0" else None
prompt_token_ids = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt,
return_tensors="pt").input_ids.to(
hf_model.model.device)
prompt_token_ids.append(token_ids)
if prompt_embeds is not None:
prompt_embeds.append(hf_model.model.get_input_embeddings()(
token_ids).squeeze(0))
with vllm_runner(
model,
tokenizer_name=model_info.tokenizer or model,
@@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if prompt_embeds is not None:
vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
prompt_embeds, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
@@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
name_0="hf",
name_1="vllm",
)
if prompt_embeds is not None:
check_logprobs_close(
outputs_0_lst=vllm_outputs,
outputs_1_lst=vllm_outputs_from_embeds,
name_0="vllm",
name_1="vllm_from_embeds",
)
if use_rocm_aiter:
# this is to ensure that vllm engine
# has deallocated the memory before running the next