Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
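For context on the mechanical changes below: ruff's formatter is Black-compatible, so it replaces yapf's continuation style (arguments aligned under the opening parenthesis) with parenthesized blocks indented by four spaces, and it honors the "magic trailing comma" to keep one item per line. A minimal sketch of the pattern, using a toy function rather than anything from this diff:

    def describe(name: str, *, count: int = 1, upper: bool = False) -> str:
        """Toy helper, used only to illustrate the formatting change."""
        text = f"{name} x{count}"
        return text.upper() if upper else text

    # yapf aligned continuation arguments under the opening parenthesis:
    #
    #     result = describe("stop_sign",
    #                       count=2,
    #                       upper=True)
    #
    # ruff (Black style) instead breaks the call one argument per line,
    # because the trailing comma after the final argument forces it:
    result = describe(
        "stop_sign",
        count=2,
        upper=True,
    )

    # Without a trailing comma, ruff collapses a call onto a single line
    # whenever it fits within the configured line length:
    compact = describe("stop_sign", count=2, upper=True)
    print(result, compact)

Import sorting moves from isort to ruff's `I` rules, which is why the wrapped `from ... import (...)` blocks are reflowed as well. To reproduce the formatting locally, the usual commands for this kind of setup are `ruff format` and `ruff check --select I --fix`; the exact invocation depends on the repository's pre-commit configuration.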
@@ -30,22 +30,27 @@ import torch.nn as nn
 import torch.nn.functional as F
 from huggingface_hub import snapshot_download
 from PIL import Image
-from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          BatchEncoding, BatchFeature)
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BatchEncoding,
+    BatchFeature,
+)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass

-from tests.models.utils import (TokensTextLogprobs,
-                                TokensTextLogprobsPromptLogprobs)
+from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
 from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
-from vllm.config.model import (ConvertOption, RunnerOption,
-                               _get_and_verify_dtype)
+from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
 from vllm.connections import global_http_connection
-from vllm.distributed import (cleanup_dist_env_and_memory,
-                              init_distributed_environment,
-                              initialize_model_parallel)
+from vllm.distributed import (
+    cleanup_dist_env_and_memory,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.multimodal.utils import fetch_image
@@ -82,12 +87,13 @@ class ImageAssetPrompts(TypedDict):


 class ImageTestAssets(list[ImageAsset]):
-
     def __init__(self) -> None:
-        super().__init__([
-            ImageAsset("stop_sign"),
-            ImageAsset("cherry_blossom"),
-        ])
+        super().__init__(
+            [
+                ImageAsset("stop_sign"),
+                ImageAsset("cherry_blossom"),
+            ]
+        )

     def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
         """
@@ -104,11 +110,12 @@ class VideoAssetPrompts(TypedDict):


 class VideoTestAssets(list[VideoAsset]):
-
     def __init__(self) -> None:
-        super().__init__([
-            VideoAsset("baby_reading"),
-        ])
+        super().__init__(
+            [
+                VideoAsset("baby_reading"),
+            ]
+        )

     def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
         return [prompts["baby_reading"]]
@@ -120,12 +127,13 @@ class AudioAssetPrompts(TypedDict):


 class AudioTestAssets(list[AudioAsset]):
-
     def __init__(self) -> None:
-        super().__init__([
-            AudioAsset("mary_had_lamb"),
-            AudioAsset("winning_call"),
-        ])
+        super().__init__(
+            [
+                AudioAsset("mary_had_lamb"),
+                AudioAsset("winning_call"),
+            ]
+        )

     def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
         return [prompts["mary_had_lamb"], prompts["winning_call"]]
@@ -220,6 +228,7 @@ def example_system_message() -> str:

 class DecoderPromptType(Enum):
     """For encoder/decoder models only."""
+
     CUSTOM = 1
     NONE = 2
     EMPTY_STR = 3
@@ -253,15 +262,13 @@ _R = TypeVar("_R")


 class HfRunner:
-
     def get_default_device(self):
         from vllm.platforms import current_platform

-        return ("cpu"
-                if current_platform.is_cpu() else current_platform.device_type)
+        return "cpu" if current_platform.is_cpu() else current_platform.device_type

     def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
-        if x is None or isinstance(x, (bool, )):
+        if x is None or isinstance(x, (bool,)):
             return x

         if device is None:
@@ -289,8 +296,11 @@ class HfRunner:
         # Set this to avoid hanging issue
         default_torch_num_threads: Optional[int] = None,
     ) -> None:
-        init_ctx = (nullcontext() if default_torch_num_threads is None else
-                    set_default_torch_num_threads(default_torch_num_threads))
+        init_ctx = (
+            nullcontext()
+            if default_torch_num_threads is None
+            else set_default_torch_num_threads(default_torch_num_threads)
+        )

         with init_ctx:
             self._init(
@@ -362,14 +372,15 @@ class HfRunner:
        )

        # in case some unquantized custom models are not in same dtype
-        if (getattr(model, "quantization_method", None) is None
-                and any(p.dtype != self.dtype
-                        for p in model.parameters())):
+        if getattr(model, "quantization_method", None) is None and any(
+            p.dtype != self.dtype for p in model.parameters()
+        ):
            model = model.to(dtype=self.dtype)

-        if (getattr(model, "quantization_method", None) != "bitsandbytes"
-                and len({p.device
-                         for p in model.parameters()}) < 2):
+        if (
+            getattr(model, "quantization_method", None) != "bitsandbytes"
+            and len({p.device for p in model.parameters()}) < 2
+        ):
            model = model.to(device=self.device)

        self.model = model
@@ -384,6 +395,7 @@ class HfRunner:
        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
+
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
@@ -471,10 +483,9 @@ class HfRunner:
         audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        all_inputs = self.get_inputs(prompts,
-                                     images=images,
-                                     videos=videos,
-                                     audios=audios)
+        all_inputs = self.get_inputs(
+            prompts, images=images, videos=videos, audios=audios
+        )

         outputs: list[tuple[list[list[int]], list[str]]] = []
         for inputs in all_inputs:
@@ -501,16 +512,17 @@ class HfRunner:
         audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> list[tuple[list[int], str]]:
-        outputs = self.generate(prompts,
-                                do_sample=False,
-                                max_new_tokens=max_tokens,
-                                images=images,
-                                videos=videos,
-                                audios=audios,
-                                **kwargs)
+        outputs = self.generate(
+            prompts,
+            do_sample=False,
+            max_new_tokens=max_tokens,
+            images=images,
+            videos=videos,
+            audios=audios,
+            **kwargs,
+        )

-        return [(output_ids[0], output_str[0])
-                for output_ids, output_str in outputs]
+        return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]

     def generate_beam_search(
         self,
@@ -521,21 +533,22 @@ class HfRunner:
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        outputs = self.generate(prompts,
-                                do_sample=False,
-                                max_new_tokens=max_tokens,
-                                num_beams=beam_width,
-                                num_return_sequences=beam_width,
-                                images=images,
-                                videos=videos,
-                                audios=audios)
+        outputs = self.generate(
+            prompts,
+            do_sample=False,
+            max_new_tokens=max_tokens,
+            num_beams=beam_width,
+            num_return_sequences=beam_width,
+            images=images,
+            videos=videos,
+            audios=audios,
+        )

         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             for j in range(len(output_ids)):
                 output_ids[j] = [
-                    x for x in output_ids[j]
-                    if x != self.tokenizer.pad_token_id
+                    x for x in output_ids[j] if x != self.tokenizer.pad_token_id
                 ]
             outputs[i] = (output_ids, output_str)
         return outputs
@@ -549,10 +562,9 @@ class HfRunner:
         audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> list[list[torch.Tensor]]:
-        all_inputs = self.get_inputs(prompts,
-                                     images=images,
-                                     videos=videos,
-                                     audios=audios)
+        all_inputs = self.get_inputs(
+            prompts, images=images, videos=videos, audios=audios
+        )

         all_logprobs: list[list[torch.Tensor]] = []
         for inputs in all_inputs:
@@ -565,8 +577,7 @@ class HfRunner:
                 return_dict_in_generate=True,
                 **kwargs,
             )
-            seq_logprobs = self._hidden_states_to_seq_logprobs(
-                output.hidden_states)
+            seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states)
             all_logprobs.append(seq_logprobs)
         return all_logprobs

@@ -630,10 +641,9 @@ class HfRunner:
         videos: Optional[PromptVideoInput] = None,
         **kwargs: Any,
     ) -> list[TokensTextLogprobs]:
-        all_inputs = self.get_inputs(prompts,
-                                     images=images,
-                                     videos=videos,
-                                     audios=audios)
+        all_inputs = self.get_inputs(
+            prompts, images=images, videos=videos, audios=audios
+        )

         all_logprobs: list[list[dict[int, float]]] = []
         all_output_ids: list[list[int]] = []
@@ -653,8 +663,7 @@ class HfRunner:
             (
                 seq_logprobs_lst,
                 output_len,
-            ) = self._hidden_states_to_logprobs(output.hidden_states,
-                                                num_logprobs)
+            ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)

             all_logprobs.append(seq_logprobs_lst)
             seq_ids = output.sequences[0]
@@ -664,19 +673,16 @@ class HfRunner:
             all_output_strs.append(self.tokenizer.decode(output_ids))

         outputs = zip(all_output_ids, all_output_strs, all_logprobs)
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
+        return [
+            (output_ids, output_str, output_logprobs)
+            for output_ids, output_str, output_logprobs in outputs
+        ]

-    def encode(self, prompts: list[str], *args,
-               **kwargs) -> list[list[torch.Tensor]]:
+    def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
         return self.model.encode(prompts, *args, **kwargs)

-    def predict(self, prompts: list[list[str]], *args,
-                **kwargs) -> torch.Tensor:
-        return self.model.predict(prompts,
-                                  *args,
-                                  convert_to_tensor=True,
-                                  **kwargs)
+    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
+        return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)

     def __enter__(self):
         return self
@@ -727,8 +733,11 @@ class VllmRunner:
         default_torch_num_threads: Optional[int] = None,
         **kwargs,
     ) -> None:
-        init_ctx = (nullcontext() if default_torch_num_threads is None else
-                    set_default_torch_num_threads(default_torch_num_threads))
+        init_ctx = (
+            nullcontext()
+            if default_torch_num_threads is None
+            else set_default_torch_num_threads(default_torch_num_threads)
+        )

         if not kwargs.get("compilation_config", None):
             kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
@@ -760,11 +769,12 @@ class VllmRunner:
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
     ) -> list[dict[str, Any]]:
-        if any(x is not None and len(x) != len(prompts)
-               for x in [images, videos, audios]):
+        if any(
+            x is not None and len(x) != len(prompts) for x in [images, videos, audios]
+        ):
             raise ValueError(
-                "All non-None multimodal inputs must have the same length as "
-                "prompts")
+                "All non-None multimodal inputs must have the same length as prompts"
+            )

         inputs = list[dict[str, Any]]()
         for i, prompt in enumerate(prompts):
@@ -800,14 +810,11 @@ class VllmRunner:
         audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        inputs = self.get_inputs(prompts,
-                                 images=images,
-                                 videos=videos,
-                                 audios=audios)
+        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

-        req_outputs = self.llm.generate(inputs,
-                                        sampling_params=sampling_params,
-                                        **kwargs)
+        req_outputs = self.llm.generate(
+            inputs, sampling_params=sampling_params, **kwargs
+        )

         outputs: list[tuple[list[list[int]], list[str]]] = []
         for req_output in req_outputs:
@@ -834,8 +841,9 @@ class VllmRunner:
             output_str = sample.text
             output_ids = list(sample.token_ids)
             output_logprobs = sample.logprobs
-            outputs.append((output_ids, output_str, output_logprobs,
-                            req_output.prompt_logprobs))
+            outputs.append(
+                (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
+            )
         return outputs

     def generate_w_logprobs(
@@ -846,23 +854,22 @@ class VllmRunner:
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
         **kwargs: Any,
-    ) -> Union[list[TokensTextLogprobs],
-               list[TokensTextLogprobsPromptLogprobs]]:
-        inputs = self.get_inputs(prompts,
-                                 images=images,
-                                 videos=videos,
-                                 audios=audios)
+    ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
+        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

-        req_outputs = self.llm.generate(inputs,
-                                        sampling_params=sampling_params,
-                                        **kwargs)
+        req_outputs = self.llm.generate(
+            inputs, sampling_params=sampling_params, **kwargs
+        )

-        toks_str_logsprobs_prompt_logprobs = (
-            self._final_steps_generate_w_logprobs(req_outputs))
+        toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
+            req_outputs
+        )
         # Omit prompt logprobs if not required by sampling params
-        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
-                if sampling_params.prompt_logprobs is None else
-                toks_str_logsprobs_prompt_logprobs)
+        return (
+            [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+            if sampling_params.prompt_logprobs is None
+            else toks_str_logsprobs_prompt_logprobs
+        )

     def generate_greedy(
         self,
@@ -874,14 +881,15 @@ class VllmRunner:
         **kwargs: Any,
     ) -> list[tuple[list[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        outputs = self.generate(prompts,
-                                greedy_params,
-                                images=images,
-                                videos=videos,
-                                audios=audios,
-                                **kwargs)
-        return [(output_ids[0], output_str[0])
-                for output_ids, output_str in outputs]
+        outputs = self.generate(
+            prompts,
+            greedy_params,
+            images=images,
+            videos=videos,
+            audios=audios,
+            **kwargs,
+        )
+        return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]

     def generate_greedy_logprobs(
         self,
@@ -895,22 +903,24 @@ class VllmRunner:
         stop_token_ids: Optional[list[int]] = None,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Union[list[TokensTextLogprobs],
-               list[TokensTextLogprobsPromptLogprobs]]:
+    ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
         greedy_logprobs_params = SamplingParams(
             temperature=0.0,
             max_tokens=max_tokens,
             logprobs=num_logprobs,
             prompt_logprobs=num_prompt_logprobs,
             stop_token_ids=stop_token_ids,
-            stop=stop)
+            stop=stop,
+        )

-        return self.generate_w_logprobs(prompts,
-                                        greedy_logprobs_params,
-                                        images=images,
-                                        audios=audios,
-                                        videos=videos,
-                                        **kwargs)
+        return self.generate_w_logprobs(
+            prompts,
+            greedy_logprobs_params,
+            images=images,
+            audios=audios,
+            videos=videos,
+            **kwargs,
+        )

     def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
         """
@@ -919,10 +929,9 @@ class VllmRunner:
        :param prompts: list of prompts to score
        :return: perplexity score of each prompt
        """
-        outputs = self.generate_greedy_logprobs(prompts,
-                                                max_tokens=1,
-                                                num_logprobs=None,
-                                                num_prompt_logprobs=0)
+        outputs = self.generate_greedy_logprobs(
+            prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
+        )

        perplexities = []
        for output in outputs:
@@ -951,15 +960,13 @@ class VllmRunner:
         audios: Optional[PromptAudioInput] = None,
         concurrency_limit: Optional[int] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        inputs = self.get_inputs(prompts,
-                                 images=images,
-                                 videos=videos,
-                                 audios=audios)
+        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

-        outputs = self.llm.beam_search(inputs,
-                                       BeamSearchParams(beam_width=beam_width,
-                                                        max_tokens=max_tokens),
-                                       concurrency_limit=concurrency_limit)
+        outputs = self.llm.beam_search(
+            inputs,
+            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
+            concurrency_limit=concurrency_limit,
+        )
         returned_outputs = []
         for output in outputs:
             token_ids = [x.tokens for x in output.sequences]
@@ -971,17 +978,16 @@ class VllmRunner:
         req_outputs = self.llm.classify(prompts)
         return [req_output.outputs.probs for req_output in req_outputs]

-    def embed(self,
-              prompts: list[str],
-              images: Optional[PromptImageInput] = None,
-              videos: Optional[PromptVideoInput] = None,
-              audios: Optional[PromptAudioInput] = None,
-              *args,
-              **kwargs) -> list[list[float]]:
-        inputs = self.get_inputs(prompts,
-                                 images=images,
-                                 videos=videos,
-                                 audios=audios)
+    def embed(
+        self,
+        prompts: list[str],
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
+        *args,
+        **kwargs,
+    ) -> list[list[float]]:
+        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

         req_outputs = self.llm.embed(inputs, *args, **kwargs)
         return [req_output.outputs.embedding for req_output in req_outputs]
@@ -1026,6 +1032,7 @@ def vllm_runner():
 @pytest.fixture()
 def temporary_enable_log_propagate():
     import logging
+
     logger = logging.getLogger("vllm")
     logger.propagate = True
     yield
@@ -1045,6 +1052,7 @@ def num_gpus_available():
     in current process."""

     from vllm.platforms import current_platform
+
     return current_platform.device_count()

@@ -1058,12 +1066,11 @@ _dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
 def dummy_opt_path():
     json_path = os.path.join(_dummy_opt_path, "config.json")
     if not os.path.exists(_dummy_opt_path):
-        snapshot_download(repo_id="facebook/opt-125m",
-                          local_dir=_dummy_opt_path,
-                          ignore_patterns=[
-                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
-                              "*.msgpack"
-                          ])
+        snapshot_download(
+            repo_id="facebook/opt-125m",
+            local_dir=_dummy_opt_path,
+            ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
+        )
     assert os.path.exists(json_path)
     with open(json_path) as f:
         config = json.load(f)
@@ -1077,12 +1084,18 @@ def dummy_opt_path():
 def dummy_llava_path():
     json_path = os.path.join(_dummy_llava_path, "config.json")
     if not os.path.exists(_dummy_llava_path):
-        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
-                          local_dir=_dummy_llava_path,
-                          ignore_patterns=[
-                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
-                              "*.msgpack", "*.safetensors"
-                          ])
+        snapshot_download(
+            repo_id="llava-hf/llava-1.5-7b-hf",
+            local_dir=_dummy_llava_path,
+            ignore_patterns=[
+                "*.bin",
+                "*.bin.index.json",
+                "*.pt",
+                "*.h5",
+                "*.msgpack",
+                "*.safetensors",
+            ],
+        )
     assert os.path.exists(json_path)
     with open(json_path) as f:
         config = json.load(f)
@@ -1096,12 +1109,18 @@ def dummy_llava_path():
 def dummy_gemma2_embedding_path():
     json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
     if not os.path.exists(_dummy_gemma2_embedding_path):
-        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
-                          local_dir=_dummy_gemma2_embedding_path,
-                          ignore_patterns=[
-                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
-                              "*.msgpack", "*.safetensors"
-                          ])
+        snapshot_download(
+            repo_id="BAAI/bge-multilingual-gemma2",
+            local_dir=_dummy_gemma2_embedding_path,
+            ignore_patterns=[
+                "*.bin",
+                "*.bin.index.json",
+                "*.pt",
+                "*.h5",
+                "*.msgpack",
+                "*.safetensors",
+            ],
+        )
     assert os.path.exists(json_path)
     with open(json_path) as f:
         config = json.load(f)
@@ -1114,10 +1133,9 @@ def dummy_gemma2_embedding_path():
 # Add the flag `--optional` to allow run tests
 # that are marked with @pytest.mark.optional
 def pytest_addoption(parser):
-    parser.addoption("--optional",
-                     action="store_true",
-                     default=False,
-                     help="run optional test")
+    parser.addoption(
+        "--optional", action="store_true", default=False, help="run optional test"
+    )


 def pytest_collection_modifyitems(config, items):
@@ -1185,7 +1203,6 @@ def _find_free_port() -> int:


 class LocalAssetServer:
-
     address: str
     port: int
     server: Optional[http.server.ThreadingHTTPServer]
@@ -1200,9 +1217,9 @@ class LocalAssetServer:
     def __enter__(self):
         self.port = _find_free_port()
         self.server = http.server.ThreadingHTTPServer(
-            (self.address, self.port), AssetHandler)
-        self.thread = threading.Thread(target=self.server.serve_forever,
-                                       daemon=True)
+            (self.address, self.port), AssetHandler
+        )
+        self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
         self.thread.start()
         return self

|
||||
@pytest.fixture(scope="session")
|
||||
def local_asset_server() -> Generator[LocalAssetServer, None, None]:
|
||||
"""
|
||||
Starts a thread based HTTP server bound to 127.0.0.1 on a random free port.
|
||||
Starts a thread based HTTP server bound to 127.0.0.1 on a random free port.
|
||||
The server currently servers images at:
|
||||
http://127.0.0.1:<port>/<name>.<ext>
|
||||
"""
|
||||
|
||||