[Speculative decoding] Support target-model logprobs (#4378)
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
import time
|
||||
from itertools import cycle
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
|
||||
nvmlInit)
|
||||
|
||||
from tests.conftest import cleanup
|
||||
from vllm import LLM
|
||||
@@ -13,7 +17,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import MultiModalData
|
||||
from vllm.sequence import Logprob, MultiModalData
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, random_uuid
|
||||
|
||||
@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
||||
test_name = request.node.name
|
||||
|
||||
def generator_inner():
|
||||
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=list(range(torch.cuda.device_count())),
|
||||
threshold_bytes=2 * 2**30,
|
||||
timeout_s=60,
|
||||
)
|
||||
|
||||
use_async = False
|
||||
if "use_async" in kwargs:
|
||||
use_async = kwargs.pop("use_async")
|
||||
print(f'{use_async=}')
|
||||
|
||||
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
|
||||
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
|
||||
set_random_seed(seed)
|
||||
|
||||
@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
|
||||
return tokens, token_ids
|
||||
|
||||
|
||||
def get_logprobs_from_llm_generator(
|
||||
llm_generator, prompts,
|
||||
sampling_params) -> List[List[Dict[int, Logprob]]]:
|
||||
"""Returns a dict of (token_id: Logprob) for each generated position, for
|
||||
each sequence in the batch.
|
||||
"""
|
||||
for llm in llm_generator():
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
logprobs = [output.outputs[0].logprobs[:] for output in outputs]
|
||||
del llm
|
||||
|
||||
return logprobs
|
||||
|
||||
|
||||
def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
print(f'{i=} {baseline_token_ids=}')
|
||||
print(f'{i=} {spec_token_ids=}')
|
||||
assert baseline_token_ids == spec_token_ids
|
||||
|
||||
|
||||
def wait_for_gpu_memory_to_clear(devices: List[int],
|
||||
threshold_bytes: int,
|
||||
timeout_s: float = 120) -> None:
|
||||
# Use nvml instead of pytorch to reduce measurement error from torch cuda
|
||||
# context.
|
||||
nvmlInit()
|
||||
start_time = time.time()
|
||||
while True:
|
||||
output = {}
|
||||
output_raw = {}
|
||||
for device in devices:
|
||||
dev_handle = nvmlDeviceGetHandleByIndex(device)
|
||||
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
|
||||
gb_used = mem_info.used / 2**30
|
||||
output_raw[device] = gb_used
|
||||
output[device] = f'{gb_used:.02f}'
|
||||
|
||||
print('gpu memory used (GB): ', end='')
|
||||
for k, v in output.items():
|
||||
print(f'{k}={v}; ', end='')
|
||||
print('')
|
||||
|
||||
dur_s = time.time() - start_time
|
||||
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
|
||||
print(f'Done waiting for free GPU memory on devices {devices=} '
|
||||
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
|
||||
break
|
||||
|
||||
if dur_s >= timeout_s:
|
||||
raise ValueError(f'Memory of devices {devices=} not free after '
|
||||
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
Reference in New Issue
Block a user