[Frontend] [Core] perf: Automatically detect vLLM-tensorized model, update tensorizer to version 2.9.0 (#4208)
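
In short: serialized vLLM models now carry enough information for the loader to detect them automatically, so the `vllm_tensorized` flag disappears from `TensorizerConfig`, and the tests' ad-hoc `TensorSerializer` calls are replaced by the new `serialize_vllm_model()` helper. A minimal sketch of the resulting flow, mirroring the calls the updated tests below make (the `LLM` entry point and the local path are illustrative assumptions, not part of this diff):

```python
# Sketch only: mirrors the API calls exercised by the updated tests below.
from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         is_vllm_tensorized,
                                                         serialize_vllm_model)

config = TensorizerConfig(tensorizer_uri="/tmp/opt-125m.tensors")

model = LLM(model="facebook/opt-125m")
serialize_vllm_model(model.llm_engine, config)  # no vllm_tensorized=True needed

# The artifact now self-identifies, so loading needs no flag either:
assert is_vllm_tensorized(config)
loaded = LLM(model="facebook/opt-125m",
             load_format="tensorizer",
             model_loader_extra_config=config)
```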
@@ -10,12 +10,19 @@ import ray
 import torch
 
 from vllm import SamplingParams
-from vllm.model_executor.model_loader.tensorizer import (
-    EncryptionParams, TensorizerConfig, TensorSerializer,
-    is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream)
+# yapf: disable
+from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
+                                                         TensorSerializer,
+                                                         is_vllm_tensorized,
+                                                         load_with_tensorizer,
+                                                         open_stream,
+                                                         serialize_vllm_model)
 
 from ..utils import ServerRunner
 
+# yapf conflicts with isort for this docstring
+
+
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -40,7 +47,7 @@ def is_curl_installed():
 
 
 @pytest.fixture(autouse=True)
 def tensorizer_config():
-    config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
+    config = TensorizerConfig(tensorizer_uri="vllm")
     return config
@@ -59,47 +66,6 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
     assert result == mock_agent_instance.deserialize.return_value
 
 
-def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
-    tensorizer_config.vllm_tensorized = True
-
-    result = is_vllm_serialized_tensorizer(tensorizer_config)
-
-    assert result is True
-
-
-def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
-    tensorizer_config.vllm_tensorized = False
-
-    result = is_vllm_serialized_tensorizer(tensorizer_config)
-
-    assert result is False
-
-
-def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
-    vllm_model = vllm_runner(model_ref)
-    model_path = tmp_path / (model_ref + ".tensors")
-    outputs = vllm_model.generate(prompts, sampling_params)
-    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
-             model_runner.model)
-    with open_stream(model_path, "wb+") as stream:
-        serializer = TensorSerializer(stream)
-        serializer.write_module(model)
-    del vllm_model, model
-    gc.collect()
-    torch.cuda.empty_cache()
-    loaded_vllm_model = vllm_runner(
-        model_ref,
-        load_format="tensorizer",
-        model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path,
-                                                   num_readers=1,
-                                                   vllm_tensorized=True),
-    )
-    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
-
-    # Assumes SamplingParams being seeded ensures the outputs are deterministic
-    assert outputs == deserialized_outputs
-
-
 @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
 def test_can_deserialize_s3(vllm_runner):
     model_ref = "EleutherAI/pythia-1.4b"
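
The two deleted tests above only verified that `is_vllm_serialized_tensorizer()` echoed the user-supplied `vllm_tensorized` flag back. Its replacement, `is_vllm_tensorized()`, is exercised against a freshly serialized artifact at the bottom of this diff. The contrast, roughly (flag semantics inferred from the deleted assertions; the path is illustrative):

```python
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         is_vllm_tensorized)

# Before: the caller declared the format up front, and the check simply
# trusted that declaration:
#   config = TensorizerConfig(tensorizer_uri="model.tensors",
#                             vllm_tensorized=True)
#   assert is_vllm_serialized_tensorizer(config)  # just echoes the flag

# After: the check inspects the serialized artifact itself.
config = TensorizerConfig(tensorizer_uri="model.tensors")
print(is_vllm_tensorized(config))  # True only for vLLM-serialized artifacts
```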
@@ -110,7 +76,6 @@ def test_can_deserialize_s3(vllm_runner):
         model_loader_extra_config=TensorizerConfig(
             tensorizer_uri=tensorized_path,
             num_readers=1,
-            vllm_tensorized=False,
             s3_endpoint="object.ord1.coreweave.com",
         ))
@@ -126,29 +91,26 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
     model_path = tmp_path / (model_ref + ".tensors")
     key_path = tmp_path / (model_ref + ".key")
     outputs = vllm_model.generate(prompts, sampling_params)
-    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
-             model_runner.model)
-
-    encryption_params = EncryptionParams.random()
-    with open_stream(model_path, "wb+") as stream:
-        serializer = TensorSerializer(stream, encryption=encryption_params)
-        serializer.write_module(model)
-    with open_stream(key_path, "wb+") as stream:
-        stream.write(encryption_params.key)
-    del vllm_model, model
+
+    config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
+    serialize_vllm_model(vllm_model.model.llm_engine,
+                         config_for_serializing,
+                         encryption_key_path=key_path)
+
+    del vllm_model
     gc.collect()
     torch.cuda.empty_cache()
-    loaded_vllm_model = vllm_runner(model_ref,
-                                    load_format="tensorizer",
-                                    model_loader_extra_config=TensorizerConfig(
-                                        tensorizer_uri=model_path,
-                                        encryption_keyfile=key_path,
-                                        num_readers=1,
-                                        vllm_tensorized=True))
+
+    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
+                                                encryption_keyfile=key_path)
+
+    loaded_vllm_model = vllm_runner(
+        model_ref,
+        load_format="tensorizer",
+        model_loader_extra_config=config_for_deserializing)
 
     deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
 
     # Assumes SamplingParams being seeded ensures the outputs are deterministic
     assert outputs == deserialized_outputs
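
The encrypted-model test no longer generates and persists the key inline; `serialize_vllm_model(..., encryption_key_path=...)` takes that over. Judging from the removed lines, the helper plausibly wraps something like the following, a sketch reconstructed from the deleted test code rather than the helper's actual body:

```python
# Sketch of what encryption_key_path plausibly wraps, reconstructed from the
# inline logic this diff removes; not the helper's actual implementation.
import torch
from tensorizer import EncryptionParams, TensorSerializer
from tensorizer.stream_io import open_stream

def serialize_encrypted(model: torch.nn.Module, model_path: str,
                        key_path: str) -> None:
    # Generate a fresh key and persist it, so later loads can supply it
    # via TensorizerConfig(encryption_keyfile=key_path).
    encryption_params = EncryptionParams.random()
    with open_stream(key_path, "wb+") as stream:
        stream.write(encryption_params.key)
    # Serialize the module with encryption enabled.
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
```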
@@ -169,7 +131,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
         model_loader_extra_config=TensorizerConfig(
             tensorizer_uri=model_path,
             num_readers=1,
-            vllm_tensorized=False))
+        ))
 
     deserialized_outputs = loaded_hf_model.generate_greedy(
         prompts, max_tokens=max_tokens)
@@ -190,12 +152,11 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
     # Serialize model before deserializing and binding LoRA adapters
     vllm_model = vllm_runner(model_ref, )
     model_path = tmp_path / (model_ref + ".tensors")
-    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
-             model_runner.model)
-    with open_stream(model_path, "wb+") as stream:
-        serializer = TensorSerializer(stream)
-        serializer.write_module(model)
-    del vllm_model, model
+
+    serialize_vllm_model(vllm_model.model.llm_engine,
+                         TensorizerConfig(tensorizer_uri=model_path))
+
+    del vllm_model
     gc.collect()
     torch.cuda.empty_cache()
     loaded_vllm_model = vllm_runner(
@@ -204,7 +165,6 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
         model_loader_extra_config=TensorizerConfig(
             tensorizer_uri=model_path,
             num_readers=1,
-            vllm_tensorized=True,
         ),
         enable_lora=True,
         max_loras=1,
@@ -220,58 +180,28 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
 
 def test_load_without_tensorizer_load_format(vllm_runner):
     with pytest.raises(ValueError):
-        vllm_runner(model_ref,
-                    model_loader_extra_config=TensorizerConfig(
-                        tensorizer_uri="test", vllm_tensorized=False))
+        vllm_runner(
+            model_ref,
+            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
 
 
-@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
-def test_tensorize_vllm_model(tmp_path):
-    # Test serialize command
-    serialize_args = [
-        "python3", tensorize_model_for_testing_script, "--model", model_ref,
-        "--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
-        "--suffix", "tests"
-    ]
-    result = subprocess.run(serialize_args, capture_output=True, text=True)
-    print(result.stdout)  # Print the output of the serialize command
-
-    assert result.returncode == 0, (f"Serialize command failed with output:"
-                                    f"\n{result.stdout}\n{result.stderr}")
-
-    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
-
-    # Test deserialize command
-    deserialize_args = [
-        "python3", tensorize_model_for_testing_script, "--model", model_ref,
-        "--dtype", "float16", "deserialize", "--path-to-tensors",
-        path_to_tensors
-    ]
-    result = subprocess.run(deserialize_args, capture_output=True, text=True)
-    assert result.returncode == 0, (f"Deserialize command failed with output:"
-                                    f"\n{result.stdout}\n{result.stderr}")
-
-
 @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
-def test_openai_apiserver_with_tensorizer(tmp_path):
+def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
     ## Serialize model
-    serialize_args = [
-        "python3", tensorize_model_for_testing_script, "--model", model_ref,
-        "--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
-        "--suffix", "tests"
-    ]
-    result = subprocess.run(serialize_args, capture_output=True, text=True)
-    print(result.stdout)  # Print the output of the serialize command
-
-    assert result.returncode == 0, (f"Serialize command failed with output:"
-                                    f"\n{result.stdout}\n{result.stderr}")
-
-    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
+    vllm_model = vllm_runner(model_ref, )
+    model_path = tmp_path / (model_ref + ".tensors")
+
+    serialize_vllm_model(vllm_model.model.llm_engine,
+                         TensorizerConfig(tensorizer_uri=model_path))
+
     model_loader_extra_config = {
-        "tensorizer_uri": path_to_tensors,
-        "vllm_tensorized": True
+        "tensorizer_uri": str(model_path),
     }
 
+    del vllm_model
+    gc.collect()
+    torch.cuda.empty_cache()
+
     ## Start OpenAI API server
     openai_args = [
         "--model", model_ref, "--dtype", "float16", "--load-format",
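
The API-server test now serializes in-process and passes only `tensorizer_uri` in `model_loader_extra_config`. Note the `str(model_path)`: the dict presumably gets serialized to JSON for the CLI, which a bare `PosixPath` would break. The argument list is truncated in this view; a hedged sketch of how such a dict would reach the server (flag names assumed from vLLM's OpenAI server CLI, treat as illustrative):

```python
import json

# Illustrative path; the test builds this from tmp_path.
model_loader_extra_config = {"tensorizer_uri": "/tmp/facebook/opt-125m.tensors"}

# Assumed continuation of the truncated openai_args above.
openai_args = [
    "--model", "facebook/opt-125m", "--dtype", "float16",
    "--load-format", "tensorizer",
    "--model-loader-extra-config", json.dumps(model_loader_extra_config),
]
```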
@@ -304,10 +234,10 @@ def test_openai_apiserver_with_tensorizer(tmp_path):
 
 def test_raise_value_error_on_invalid_load_format(vllm_runner):
     with pytest.raises(ValueError):
-        vllm_runner(model_ref,
-                    load_format="safetensors",
-                    model_loader_extra_config=TensorizerConfig(
-                        tensorizer_uri="test", vllm_tensorized=False))
+        vllm_runner(
+            model_ref,
+            load_format="safetensors",
+            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
 
 
 def test_tensorizer_with_tp(vllm_runner):
@@ -321,8 +251,29 @@ def test_tensorizer_with_tp(vllm_runner):
         model_loader_extra_config=TensorizerConfig(
             tensorizer_uri=tensorized_path,
             num_readers=1,
-            vllm_tensorized=False,
            s3_endpoint="object.ord1.coreweave.com",
         ),
         tensor_parallel_size=2,
     )
+
+
+def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
+    model_ref = "facebook/opt-125m"
+    model_path = tmp_path / (model_ref + ".tensors")
+    config = TensorizerConfig(tensorizer_uri=str(model_path))
+
+    vllm_model = vllm_runner(model_ref)
+    outputs = vllm_model.generate(prompts, sampling_params)
+    serialize_vllm_model(vllm_model.model.llm_engine, config)
+
+    assert is_vllm_tensorized(config)
+    del vllm_model
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    loaded_vllm_model = vllm_runner(model_ref,
+                                    load_format="tensorizer",
+                                    model_loader_extra_config=config)
+    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
+
+    assert outputs == deserialized_outputs
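
The new test above asserts `is_vllm_tensorized(config)` immediately after `serialize_vllm_model()`, so detection must be a property of the serialized artifact itself rather than of any flag. This diff does not show the mechanism; one plausible implementation (an assumption, not the code in this commit) is a sentinel tensor written during serialization and looked up by name at load time:

```python
# Assumed sketch of marker-based detection; the actual mechanism lives in
# vllm/model_executor/model_loader/tensorizer.py and may differ.
import torch
from torch.nn import Parameter

MARKER = "vllm_tensorized_marker"  # hypothetical marker name

def tag_as_vllm(model: torch.nn.Module) -> torch.nn.Module:
    # During serialization: register a tiny extra parameter so the
    # serialized state dict self-identifies as vLLM-tensorized.
    model.register_parameter(
        MARKER, Parameter(torch.tensor([1.0]), requires_grad=False))
    return model

def has_vllm_marker(tensor_names) -> bool:
    # During load: with tensorizer's lazy deserialization, tensor names
    # can be scanned without materializing any weights.
    return any(name.endswith(MARKER) for name in tensor_names)
```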