[CI/Test] improve robustness of test (vllm_runner) (#5357)

[CI/Test] improve robustness of test by replacing del with context manager (vllm_runner) (#5357)
commit 8ea5e44a43
parent 9fb900f90c
Author: youkaichao
Date:   2024-06-08 01:59:20 -07:00
Committed by: GitHub
28 changed files with 431 additions and 470 deletions
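
The change replaces manual teardown (`del vllm_model` followed by `gc.collect()` and `torch.cuda.empty_cache()`) with the context-manager protocol, so GPU memory is released even when a test fails partway through. The diff does not include the runner fixture itself, so the snippet below is only a minimal sketch of the idea; `RunnerSketch` is a hypothetical stand-in for whatever `vllm_runner` actually constructs, not the real implementation.

# Minimal sketch only: RunnerSketch is a hypothetical stand-in for the real
# vllm_runner fixture (defined in the test suite's conftest.py, not shown here).
import gc

import torch


class RunnerSketch:

    def __init__(self, model_name, **engine_kwargs):
        # Stand-in for building the real LLM engine.
        self.model_name = model_name
        self.engine_kwargs = engine_kwargs
        self.model = object()  # placeholder for the loaded model

    def generate(self, prompts, sampling_params=None):
        # Stand-in for running inference.
        return [f"{self.model_name}: {prompt}" for prompt in prompts]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs even if the test body raises, which the removed
        # `del ...; gc.collect(); torch.cuda.empty_cache()` lines could
        # not guarantee.
        del self.model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Usage mirroring the updated tests:
with RunnerSketch("EleutherAI/pythia-1.4b") as runner:
    outputs = runner.generate(["Hello, my name is"])
    assert outputs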

@@ -1,4 +1,3 @@
-import gc
 import json
 import os
 import subprocess
@@ -7,7 +6,6 @@ from unittest.mock import MagicMock, patch
 import openai
 import pytest
 import ray
-import torch
 
 from vllm import SamplingParams
 # yapf: disable
@@ -71,47 +69,43 @@ def test_can_deserialize_s3(vllm_runner):
     model_ref = "EleutherAI/pythia-1.4b"
     tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
 
-    loaded_hf_model = vllm_runner(model_ref,
+    with vllm_runner(model_ref,
                      load_format="tensorizer",
                      model_loader_extra_config=TensorizerConfig(
                          tensorizer_uri=tensorized_path,
                          num_readers=1,
                          s3_endpoint="object.ord1.coreweave.com",
-                     ))
-
-    deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
+                     )) as loaded_hf_model:
+        deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)  # noqa: E501
 
-    assert deserialized_outputs
+        assert deserialized_outputs
 
 
 @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
 def test_deserialized_encrypted_vllm_model_has_same_outputs(
         vllm_runner, tmp_path):
-    vllm_model = vllm_runner(model_ref)
-    model_path = tmp_path / (model_ref + ".tensors")
-    key_path = tmp_path / (model_ref + ".key")
-    outputs = vllm_model.generate(prompts, sampling_params)
+    with vllm_runner(model_ref) as vllm_model:
+        model_path = tmp_path / (model_ref + ".tensors")
+        key_path = tmp_path / (model_ref + ".key")
+        outputs = vllm_model.generate(prompts, sampling_params)
 
-    config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
-    serialize_vllm_model(vllm_model.model.llm_engine,
-                         config_for_serializing,
-                         encryption_key_path=key_path)
-
-    del vllm_model
-    gc.collect()
-    torch.cuda.empty_cache()
+        config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
+        serialize_vllm_model(vllm_model.model.llm_engine,
+                             config_for_serializing,
+                             encryption_key_path=key_path)
 
     config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                 encryption_keyfile=key_path)
 
-    loaded_vllm_model = vllm_runner(
+    with vllm_runner(
         model_ref,
         load_format="tensorizer",
-        model_loader_extra_config=config_for_deserializing)
+        model_loader_extra_config=config_for_deserializing) as loaded_vllm_model:  # noqa: E501
 
-    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
+        deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)  # noqa: E501
 
-    assert outputs == deserialized_outputs
+        assert outputs == deserialized_outputs
 
 
 def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
@@ -124,17 +118,17 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
     serializer = TensorSerializer(stream)
     serializer.write_module(hf_model.model)
 
-    loaded_hf_model = vllm_runner(model_ref,
+    with vllm_runner(model_ref,
                      load_format="tensorizer",
                      model_loader_extra_config=TensorizerConfig(
                          tensorizer_uri=model_path,
                          num_readers=1,
-                     ))
+                     )) as loaded_hf_model:
 
-    deserialized_outputs = loaded_hf_model.generate_greedy(
-        prompts, max_tokens=max_tokens)
+        deserialized_outputs = loaded_hf_model.generate_greedy(
+            prompts, max_tokens=max_tokens)
 
-    assert outputs == deserialized_outputs
+        assert outputs == deserialized_outputs
 
 
 def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
@@ -148,16 +142,13 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
     test_prompts = create_test_prompts(lora_path)
 
     # Serialize model before deserializing and binding LoRA adapters
-    vllm_model = vllm_runner(model_ref, )
-    model_path = tmp_path / (model_ref + ".tensors")
+    with vllm_runner(model_ref, ) as vllm_model:
+        model_path = tmp_path / (model_ref + ".tensors")
 
-    serialize_vllm_model(vllm_model.model.llm_engine,
-                         TensorizerConfig(tensorizer_uri=model_path))
+        serialize_vllm_model(vllm_model.model.llm_engine,
+                             TensorizerConfig(tensorizer_uri=model_path))
 
-    del vllm_model
-    gc.collect()
-    torch.cuda.empty_cache()
-    loaded_vllm_model = vllm_runner(
+    with vllm_runner(
         model_ref,
         load_format="tensorizer",
         model_loader_extra_config=TensorizerConfig(
@@ -170,10 +161,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
         max_cpu_loras=2,
         max_num_seqs=50,
         max_model_len=1000,
-    )
-    process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
+    ) as loaded_vllm_model:
+        process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
 
-    assert loaded_vllm_model
+        assert loaded_vllm_model
 
 
 def test_load_without_tensorizer_load_format(vllm_runner):
@@ -186,19 +177,15 @@ def test_load_without_tensorizer_load_format(vllm_runner):
 @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
 def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
     ## Serialize model
-    vllm_model = vllm_runner(model_ref, )
-    model_path = tmp_path / (model_ref + ".tensors")
+    with vllm_runner(model_ref, ) as vllm_model:
+        model_path = tmp_path / (model_ref + ".tensors")
 
-    serialize_vllm_model(vllm_model.model.llm_engine,
-                         TensorizerConfig(tensorizer_uri=model_path))
+        serialize_vllm_model(vllm_model.model.llm_engine,
+                             TensorizerConfig(tensorizer_uri=model_path))
 
-    model_loader_extra_config = {
-        "tensorizer_uri": str(model_path),
-    }
-
-    del vllm_model
-    gc.collect()
-    torch.cuda.empty_cache()
+        model_loader_extra_config = {
+            "tensorizer_uri": str(model_path),
+        }
 
     ## Start OpenAI API server
     openai_args = [
@@ -260,18 +247,15 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
     model_path = tmp_path / (model_ref + ".tensors")
     config = TensorizerConfig(tensorizer_uri=str(model_path))
-    vllm_model = vllm_runner(model_ref)
-    outputs = vllm_model.generate(prompts, sampling_params)
-    serialize_vllm_model(vllm_model.model.llm_engine, config)
+    with vllm_runner(model_ref) as vllm_model:
+        outputs = vllm_model.generate(prompts, sampling_params)
+        serialize_vllm_model(vllm_model.model.llm_engine, config)
 
-    assert is_vllm_tensorized(config)
-    del vllm_model
-    gc.collect()
-    torch.cuda.empty_cache()
+    assert is_vllm_tensorized(config)
 
-    loaded_vllm_model = vllm_runner(model_ref,
-                                    load_format="tensorizer",
-                                    model_loader_extra_config=config)
-    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
+    with vllm_runner(model_ref,
+                     load_format="tensorizer",
+                     model_loader_extra_config=config) as loaded_vllm_model:
+        deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)  # noqa: E501
 
-    assert outputs == deserialized_outputs
+        assert outputs == deserialized_outputs
 
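
For completeness, a runner like the one sketched near the top of this page is typically handed to tests through a pytest fixture that returns a callable. The snippet below is an assumption about how such a fixture could look, not the repository's actual conftest.py; `RunnerSketch` refers to the hypothetical class from the earlier sketch.

# Hedged sketch of a conftest-style fixture; the real vllm_runner fixture in
# the repository may differ. RunnerSketch is the hypothetical runner above.
import pytest


@pytest.fixture
def vllm_runner():
    # Returning the class (not an instance) keeps the fixture callable, so a
    # test chooses its own constructor arguments:
    #     with vllm_runner(model_ref, load_format="tensorizer") as model:
    #         ...
    # The model is built in __init__ and torn down in __exit__.
    return RunnerSketch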