Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,8 +8,7 @@ import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import tests.ci_envs as ci_envs
|
||||
from tests.models.utils import (GenerateModelInfo,
|
||||
TokensTextLogprobsPromptLogprobs)
|
||||
from tests.models.utils import GenerateModelInfo, TokensTextLogprobsPromptLogprobs
|
||||
from vllm.logprobs import Logprob
|
||||
|
||||
# See #24485
|
||||
@@ -18,13 +17,14 @@ MAX_LENGTH = 1024
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def wikitext_ppl_test(hf_runner,
|
||||
vllm_runner,
|
||||
model_info: GenerateModelInfo,
|
||||
max_length=MAX_LENGTH,
|
||||
vllm_extra_kwargs=None,
|
||||
atol=PPL_TOL):
|
||||
|
||||
def wikitext_ppl_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model_info: GenerateModelInfo,
|
||||
max_length=MAX_LENGTH,
|
||||
vllm_extra_kwargs=None,
|
||||
atol=PPL_TOL,
|
||||
):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||
@@ -44,15 +44,16 @@ def wikitext_ppl_test(hf_runner,
|
||||
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||
if "hf_overrides" not in vllm_extra_kwargs:
|
||||
vllm_extra_kwargs["hf_overrides"] = {}
|
||||
vllm_extra_kwargs["hf_overrides"][
|
||||
"head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
gpu_memory_utilization=0.7,
|
||||
max_model_len=max_length,
|
||||
max_num_seqs=1,
|
||||
enforce_eager=True,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
with vllm_runner(
|
||||
model_info.name,
|
||||
gpu_memory_utilization=0.7,
|
||||
max_model_len=max_length,
|
||||
max_num_seqs=1,
|
||||
enforce_eager=True,
|
||||
**vllm_extra_kwargs,
|
||||
) as vllm_model:
|
||||
# Use max_num_seqs=1 to avoid OOM,
|
||||
# and avoid batch different requests together.
|
||||
|
||||
@@ -60,7 +61,7 @@ def wikitext_ppl_test(hf_runner,
|
||||
|
||||
# Confirm whether vllm is using the correct architecture
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture in model_config.architectures)
|
||||
assert model_info.architecture in model_config.architectures
|
||||
|
||||
max_length = min(model_config.max_model_len - 1, max_length)
|
||||
stride = max_length
|
||||
@@ -74,12 +75,14 @@ def wikitext_ppl_test(hf_runner,
|
||||
end_loc = min(begin_loc + max_length, n_tokens)
|
||||
chunks.append(tokens[begin_loc:end_loc])
|
||||
|
||||
outputs = vllm_model.generate_greedy_logprobs(prompts=chunks,
|
||||
max_tokens=1,
|
||||
num_logprobs=None,
|
||||
num_prompt_logprobs=0,
|
||||
use_tqdm=False)
|
||||
nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
|
||||
outputs = vllm_model.generate_greedy_logprobs(
|
||||
prompts=chunks,
|
||||
max_tokens=1,
|
||||
num_logprobs=None,
|
||||
num_prompt_logprobs=0,
|
||||
use_tqdm=False,
|
||||
)
|
||||
nll_sum = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
n_tokens = 0
|
||||
for output in outputs:
|
||||
output = cast(TokensTextLogprobsPromptLogprobs, output)
|
||||
@@ -94,7 +97,8 @@ def wikitext_ppl_test(hf_runner,
|
||||
token_log_probs.append(token_log_prob)
|
||||
|
||||
neg_log_likelihood = -torch.tensor(
|
||||
token_log_probs, dtype=torch.float32, device="cpu").sum()
|
||||
token_log_probs, dtype=torch.float32, device="cpu"
|
||||
).sum()
|
||||
nll_sum += neg_log_likelihood
|
||||
n_tokens += len(token_log_probs)
|
||||
vllm_ppl = float(torch.exp(nll_sum / n_tokens))
|
||||
@@ -104,14 +108,13 @@ def wikitext_ppl_test(hf_runner,
|
||||
# Accelerate ppl test by setting Transformers ppl score to a constant
|
||||
if model_info.hf_ppl is None:
|
||||
with hf_runner(
|
||||
model_info.name,
|
||||
dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
|
||||
model_info.name,
|
||||
dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
|
||||
) as hf_model:
|
||||
nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
|
||||
nll_sum = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
n_tokens = 0
|
||||
for chunk in chunks:
|
||||
inputs = hf_model.wrap_device(
|
||||
{"input_ids": torch.tensor([chunk])})
|
||||
inputs = hf_model.wrap_device({"input_ids": torch.tensor([chunk])})
|
||||
input_ids = inputs["input_ids"]
|
||||
outputs = hf_model.model(input_ids, labels=input_ids)
|
||||
neg_log_likelihood = outputs.loss
|
||||
|
||||
Reference in New Issue
Block a user