Push logprob generation to LLMEngine (#3065)

Co-authored-by: Avnish Narayan <avnish@anyscale.com>
This commit is contained in:
Antoni Baum
2024-03-04 11:54:06 -08:00
committed by GitHub
parent 76e8a70476
commit 22de45235c
13 changed files with 551 additions and 331 deletions

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Dict
from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SequenceGroupMetadata, SequenceData
from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
@@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts(
def assert_logprobs_dict_allclose(
        actual_logprobs: List[Dict[int, "Logprob"]],
        expected_logprobs: List[Dict[int, "Logprob"]]) -> None:
    """Assert that two per-step logprob sequences match.

    Each list element is one step: a mapping from token id to an object
    whose ``.logprob`` attribute holds the value to compare.

    Args:
        actual_logprobs: Per-step mappings produced by the code under test.
        expected_logprobs: Per-step reference mappings.

    Raises:
        AssertionError: If the number of steps differs, if a step's token-id
            sets differ, or if any pair of ``.logprob`` values is not close
            according to ``torch.allclose`` (default tolerances).
    """
    # zip() stops at the shorter input, so a missing or extra step would
    # otherwise go unnoticed; compare the lengths explicitly first.
    assert len(actual_logprobs) == len(expected_logprobs), (
        f"step count mismatch: {len(actual_logprobs)} actual vs "
        f"{len(expected_logprobs)} expected")

    for single_step_actual_logprobs, single_step_expected_logprobs in zip(
            actual_logprobs, expected_logprobs):
        # Both steps must cover exactly the same token ids.
        assert set(single_step_actual_logprobs.keys()) == set(
            single_step_expected_logprobs.keys())
        for token_id in single_step_actual_logprobs:
            # Compare the numeric value carried by each entry, not the
            # wrapper objects themselves.
            actual = torch.tensor(
                single_step_actual_logprobs[token_id].logprob)
            expected = torch.tensor(
                single_step_expected_logprobs[token_id].logprob)
            assert torch.allclose(actual, expected)