Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,46 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
@@ -6,28 +6,16 @@ default_stages:
|
|||||||
- manual # Run in CI
|
- manual # Run in CI
|
||||||
exclude: 'vllm/third_party/.*'
|
exclude: 'vllm/third_party/.*'
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/google/yapf
|
|
||||||
rev: v0.43.0
|
|
||||||
hooks:
|
|
||||||
- id: yapf
|
|
||||||
args: [--in-place, --verbose]
|
|
||||||
# Keep the same list from yapfignore here to avoid yapf failing without any inputs
|
|
||||||
exclude: '(.buildkite|benchmarks|build|examples)/.*'
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.11.7
|
rev: v0.11.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--output-format, github, --fix]
|
args: [--output-format, github, --fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
files: ^(.buildkite|benchmarks|examples)/.*
|
|
||||||
- repo: https://github.com/crate-ci/typos
|
- repo: https://github.com/crate-ci/typos
|
||||||
rev: v1.35.5
|
rev: v1.35.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: typos
|
- id: typos
|
||||||
- repo: https://github.com/PyCQA/isort
|
|
||||||
rev: 6.0.1
|
|
||||||
hooks:
|
|
||||||
- id: isort
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v20.1.3
|
rev: v20.1.3
|
||||||
hooks:
|
hooks:
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
from benchmark_utils import TimeCollector
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
from benchmark_utils import TimeCollector
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import time
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from benchmark_utils import TimeCollector
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
from benchmark_utils import TimeCollector
|
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
DeviceConfig,
|
DeviceConfig,
|
||||||
|
|||||||
@@ -37,14 +37,13 @@ from typing import Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm.asyncio import tqdm
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
|
|
||||||
from backend_request_func import (
|
from backend_request_func import (
|
||||||
ASYNC_REQUEST_FUNCS,
|
ASYNC_REQUEST_FUNCS,
|
||||||
RequestFuncInput,
|
RequestFuncInput,
|
||||||
RequestFuncOutput,
|
RequestFuncOutput,
|
||||||
)
|
)
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
|
||||||
known-first-party = ["vllm"]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
@@ -16,7 +16,7 @@ import shutil
|
|||||||
|
|
||||||
from torch.utils.hipify.hipify_python import hipify
|
from torch.utils.hipify.hipify_python import hipify
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
# Project directory where all the source + include files live.
|
# Project directory where all the source + include files live.
|
||||||
@@ -34,15 +34,14 @@ if __name__ == '__main__':
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Source files to convert.
|
# Source files to convert.
|
||||||
parser.add_argument("sources",
|
parser.add_argument(
|
||||||
help="Source files to hipify.",
|
"sources", help="Source files to hipify.", nargs="*", default=[]
|
||||||
nargs="*",
|
)
|
||||||
default=[])
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Limit include scope to project_dir only
|
# Limit include scope to project_dir only
|
||||||
includes = [os.path.join(args.project_dir, '*')]
|
includes = [os.path.join(args.project_dir, "*")]
|
||||||
|
|
||||||
# Get absolute path for all source files.
|
# Get absolute path for all source files.
|
||||||
extra_files = [os.path.abspath(s) for s in args.sources]
|
extra_files = [os.path.abspath(s) for s in args.sources]
|
||||||
@@ -51,25 +50,31 @@ if __name__ == '__main__':
|
|||||||
# The directory might already exist to hold object files so we ignore that.
|
# The directory might already exist to hold object files so we ignore that.
|
||||||
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
||||||
|
|
||||||
hipify_result = hipify(project_directory=args.project_dir,
|
hipify_result = hipify(
|
||||||
|
project_directory=args.project_dir,
|
||||||
output_directory=args.output_dir,
|
output_directory=args.output_dir,
|
||||||
header_include_dirs=[],
|
header_include_dirs=[],
|
||||||
includes=includes,
|
includes=includes,
|
||||||
extra_files=extra_files,
|
extra_files=extra_files,
|
||||||
show_detailed=True,
|
show_detailed=True,
|
||||||
is_pytorch_extension=True,
|
is_pytorch_extension=True,
|
||||||
hipify_extra_files_only=True)
|
hipify_extra_files_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
hipified_sources = []
|
hipified_sources = []
|
||||||
for source in args.sources:
|
for source in args.sources:
|
||||||
s_abs = os.path.abspath(source)
|
s_abs = os.path.abspath(source)
|
||||||
hipified_s_abs = (hipify_result[s_abs].hipified_path if
|
hipified_s_abs = (
|
||||||
(s_abs in hipify_result
|
hipify_result[s_abs].hipified_path
|
||||||
and hipify_result[s_abs].hipified_path is not None)
|
if (
|
||||||
else s_abs)
|
s_abs in hipify_result
|
||||||
|
and hipify_result[s_abs].hipified_path is not None
|
||||||
|
)
|
||||||
|
else s_abs
|
||||||
|
)
|
||||||
hipified_sources.append(hipified_s_abs)
|
hipified_sources.append(hipified_s_abs)
|
||||||
|
|
||||||
assert (len(hipified_sources) == len(args.sources))
|
assert len(hipified_sources) == len(args.sources)
|
||||||
|
|
||||||
# Print hipified source files.
|
# Print hipified source files.
|
||||||
print("\n".join(hipified_sources))
|
print("\n".join(hipified_sources))
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: "u4b8",
|
VLLMDataType.u4b8: "u4b8",
|
||||||
VLLMDataType.u8b128: "u8b128",
|
VLLMDataType.u8b128: "u8b128",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
@@ -35,7 +35,7 @@ VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
||||||
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
|
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
||||||
@@ -43,7 +43,7 @@ VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
|||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: 4,
|
VLLMDataType.u4b8: 4,
|
||||||
VLLMDataType.u8b128: 8,
|
VLLMDataType.u8b128: 8,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
@@ -67,15 +67,13 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
DataType.f32: "at::ScalarType::Float",
|
DataType.f32: "at::ScalarType::Float",
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMKernelScheduleTag: dict[Union[
|
VLLMKernelScheduleTag: dict[
|
||||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
Union[MixedInputKernelScheduleType, KernelScheduleType], str
|
||||||
|
] = {
|
||||||
**KernelScheduleTag, # type: ignore
|
**KernelScheduleTag, # type: ignore
|
||||||
**{
|
**{
|
||||||
MixedInputKernelScheduleType.TmaWarpSpecialized:
|
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
|
||||||
"cutlass::gemm::KernelTmaWarpSpecialized",
|
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
|
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
},
|
||||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
|
}
|
||||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ FILE_HEAD = """
|
|||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
TEMPLATE = ("template __global__ void Marlin<"
|
TEMPLATE = (
|
||||||
|
"template __global__ void Marlin<"
|
||||||
"{{scalar_t}}, "
|
"{{scalar_t}}, "
|
||||||
"{{w_type_id}}, "
|
"{{w_type_id}}, "
|
||||||
"{{s_type_id}}, "
|
"{{s_type_id}}, "
|
||||||
@@ -29,13 +30,17 @@ TEMPLATE = ("template __global__ void Marlin<"
|
|||||||
"{{stages}}, "
|
"{{stages}}, "
|
||||||
"{{group_blocks}}, "
|
"{{group_blocks}}, "
|
||||||
"{{'true' if is_zp_float else 'false'}}>"
|
"{{'true' if is_zp_float else 'false'}}>"
|
||||||
"( MARLIN_KERNEL_PARAMS );")
|
"( MARLIN_KERNEL_PARAMS );"
|
||||||
|
)
|
||||||
|
|
||||||
# int8 with zero point case (vllm::kU8) is also supported,
|
# int8 with zero point case (vllm::kU8) is also supported,
|
||||||
# we don't add it to reduce wheel size.
|
# we don't add it to reduce wheel size.
|
||||||
SCALAR_TYPES = [
|
SCALAR_TYPES = [
|
||||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
"vllm::kU4",
|
||||||
"vllm::kFE2M1f"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
|
"vllm::kFE4M3fn",
|
||||||
|
"vllm::kFE2M1f",
|
||||||
]
|
]
|
||||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||||
|
|
||||||
@@ -58,11 +63,12 @@ def generate_new_kernels():
|
|||||||
all_template_str_list = []
|
all_template_str_list = []
|
||||||
|
|
||||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||||
|
):
|
||||||
# act order case only support gptq-int4 and gptq-int8
|
# act order case only support gptq-int4 and gptq-int8
|
||||||
if group_blocks == 0 and scalar_type not in [
|
if group_blocks == 0 and scalar_type not in [
|
||||||
"vllm::kU4B8", "vllm::kU8B128"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
]:
|
]:
|
||||||
continue
|
continue
|
||||||
if thread_configs[2] == 256:
|
if thread_configs[2] == 256:
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ FILE_HEAD = """
|
|||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
TEMPLATE = ("template __global__ void Marlin<"
|
TEMPLATE = (
|
||||||
|
"template __global__ void Marlin<"
|
||||||
"{{scalar_t}}, "
|
"{{scalar_t}}, "
|
||||||
"{{w_type_id}}, "
|
"{{w_type_id}}, "
|
||||||
"{{s_type_id}}, "
|
"{{s_type_id}}, "
|
||||||
@@ -29,16 +30,19 @@ TEMPLATE = ("template __global__ void Marlin<"
|
|||||||
"{{stages}}, "
|
"{{stages}}, "
|
||||||
"{{group_blocks}}, "
|
"{{group_blocks}}, "
|
||||||
"{{'true' if is_zp_float else 'false'}}>"
|
"{{'true' if is_zp_float else 'false'}}>"
|
||||||
"( MARLIN_KERNEL_PARAMS );")
|
"( MARLIN_KERNEL_PARAMS );"
|
||||||
|
)
|
||||||
|
|
||||||
# int8 with zero point case (vllm::kU8) is also supported,
|
# int8 with zero point case (vllm::kU8) is also supported,
|
||||||
# we don't add it to reduce wheel size.
|
# we don't add it to reduce wheel size.
|
||||||
SCALAR_TYPES = [
|
SCALAR_TYPES = [
|
||||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
"vllm::kU4",
|
||||||
"vllm::kFE2M1f"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
|
"vllm::kFE4M3fn",
|
||||||
|
"vllm::kFE2M1f",
|
||||||
]
|
]
|
||||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
|
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
|
||||||
(128, 64, 128)]
|
|
||||||
|
|
||||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||||
# group_blocks:
|
# group_blocks:
|
||||||
@@ -59,11 +63,12 @@ def generate_new_kernels():
|
|||||||
all_template_str_list = []
|
all_template_str_list = []
|
||||||
|
|
||||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||||
|
):
|
||||||
# act order case only support gptq-int4 and gptq-int8
|
# act order case only support gptq-int4 and gptq-int8
|
||||||
if group_blocks == 0 and scalar_type not in [
|
if group_blocks == 0 and scalar_type not in [
|
||||||
"vllm::kU4B8", "vllm::kU8B128"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
]:
|
]:
|
||||||
continue
|
continue
|
||||||
if thread_configs[2] == 256:
|
if thread_configs[2] == 256:
|
||||||
@@ -93,8 +98,7 @@ def generate_new_kernels():
|
|||||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||||
|
|
||||||
is_zp_float_list = [False]
|
is_zp_float_list = [False]
|
||||||
if dtype == "fp16" and scalar_type == "vllm::kU4" and \
|
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
|
||||||
group_blocks == 4:
|
|
||||||
# HQQ (is_zp_float = true) only supports
|
# HQQ (is_zp_float = true) only supports
|
||||||
# 4bit quantization and fp16
|
# 4bit quantization and fp16
|
||||||
is_zp_float_list.append(True)
|
is_zp_float_list.append(True)
|
||||||
|
|||||||
@@ -12,18 +12,24 @@ from functools import reduce
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
|
from vllm_cutlass_library_extension import (
|
||||||
|
DataType,
|
||||||
|
EpilogueScheduleTag,
|
||||||
EpilogueScheduleType,
|
EpilogueScheduleType,
|
||||||
MixedInputKernelScheduleType,
|
MixedInputKernelScheduleType,
|
||||||
TileSchedulerTag,
|
TileSchedulerTag,
|
||||||
TileSchedulerType, VLLMDataType,
|
TileSchedulerType,
|
||||||
|
VLLMDataType,
|
||||||
VLLMDataTypeNames,
|
VLLMDataTypeNames,
|
||||||
VLLMDataTypeSize, VLLMDataTypeTag,
|
VLLMDataTypeSize,
|
||||||
|
VLLMDataTypeTag,
|
||||||
VLLMDataTypeTorchDataTypeTag,
|
VLLMDataTypeTorchDataTypeTag,
|
||||||
VLLMDataTypeVLLMScalarTypeTag,
|
VLLMDataTypeVLLMScalarTypeTag,
|
||||||
VLLMKernelScheduleTag)
|
VLLMKernelScheduleTag,
|
||||||
|
)
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|
||||||
@@ -286,18 +292,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
|||||||
tile_shape = (
|
tile_shape = (
|
||||||
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
||||||
)
|
)
|
||||||
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
|
cluster_shape = (
|
||||||
f"x{schedule_config.cluster_shape_mnk[1]}" +
|
f"{schedule_config.cluster_shape_mnk[0]}"
|
||||||
f"x{schedule_config.cluster_shape_mnk[2]}")
|
+ f"x{schedule_config.cluster_shape_mnk[1]}"
|
||||||
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
|
+ f"x{schedule_config.cluster_shape_mnk[2]}"
|
||||||
.split("::")[-1]
|
)
|
||||||
epilogue_schedule = EpilogueScheduleTag[
|
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split(
|
||||||
schedule_config.epilogue_schedule].split("::")[-1]
|
"::"
|
||||||
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
|
)[-1]
|
||||||
.split("::")[-1]
|
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split(
|
||||||
|
"::"
|
||||||
|
)[-1]
|
||||||
|
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
|
||||||
|
|
||||||
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
|
return (
|
||||||
f"_{epilogue_schedule}_{tile_scheduler}")
|
f"{tile_shape}_{cluster_shape}_{kernel_schedule}"
|
||||||
|
+ f"_{epilogue_schedule}_{tile_scheduler}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# mostly unique shorter sch_sig
|
# mostly unique shorter sch_sig
|
||||||
@@ -316,18 +327,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
|||||||
|
|
||||||
# unique type_name
|
# unique type_name
|
||||||
def generate_type_signature(kernel_types: TypeConfig):
|
def generate_type_signature(kernel_types: TypeConfig):
|
||||||
return str("".join([
|
return str(
|
||||||
|
"".join(
|
||||||
|
[
|
||||||
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||||
for field in fields(TypeConfig)
|
for field in fields(TypeConfig)
|
||||||
]))
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_type_option_name(kernel_types: TypeConfig):
|
def generate_type_option_name(kernel_types: TypeConfig):
|
||||||
return ", ".join([
|
return ", ".join(
|
||||||
f"{field.name.replace('b_', 'with_')+'_type'}=" +
|
[
|
||||||
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
f"{field.name.replace('b_', 'with_') + '_type'}="
|
||||||
|
+ VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||||
for field in fields(TypeConfig)
|
for field in fields(TypeConfig)
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_power_of_two(n):
|
def is_power_of_two(n):
|
||||||
@@ -335,7 +352,6 @@ def is_power_of_two(n):
|
|||||||
|
|
||||||
|
|
||||||
def to_cute_constant(value: list[int]):
|
def to_cute_constant(value: list[int]):
|
||||||
|
|
||||||
def _to_cute_constant(value: int):
|
def _to_cute_constant(value: int):
|
||||||
if is_power_of_two(value):
|
if is_power_of_two(value):
|
||||||
return f"_{value}"
|
return f"_{value}"
|
||||||
@@ -350,11 +366,11 @@ def to_cute_constant(value: list[int]):
|
|||||||
|
|
||||||
def unique_schedules(impl_configs: list[ImplConfig]):
|
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||||
# Use dict over set for deterministic ordering
|
# Use dict over set for deterministic ordering
|
||||||
return list({
|
return list(
|
||||||
sch: None
|
{
|
||||||
for impl_config in impl_configs
|
sch: None for impl_config in impl_configs for sch in impl_config.schedules
|
||||||
for sch in impl_config.schedules
|
}.keys()
|
||||||
}.keys())
|
)
|
||||||
|
|
||||||
|
|
||||||
def unsigned_type_with_bitwidth(num_bits):
|
def unsigned_type_with_bitwidth(num_bits):
|
||||||
@@ -380,7 +396,7 @@ template_globals = {
|
|||||||
"gen_type_sig": generate_type_signature,
|
"gen_type_sig": generate_type_signature,
|
||||||
"unique_schedules": unique_schedules,
|
"unique_schedules": unique_schedules,
|
||||||
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
|
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
|
||||||
"gen_type_option_name": generate_type_option_name
|
"gen_type_option_name": generate_type_option_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -398,23 +414,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
|||||||
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||||
sources = []
|
sources = []
|
||||||
|
|
||||||
sources.append((
|
sources.append(
|
||||||
|
(
|
||||||
"machete_mm_dispatch",
|
"machete_mm_dispatch",
|
||||||
mm_dispatch_template.render(impl_configs=impl_configs),
|
mm_dispatch_template.render(impl_configs=impl_configs),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
prepack_types = []
|
prepack_types = []
|
||||||
for impl_config in impl_configs:
|
for impl_config in impl_configs:
|
||||||
convert_type = impl_config.types.a \
|
convert_type = (
|
||||||
if impl_config.types.b_group_scale == DataType.void \
|
impl_config.types.a
|
||||||
|
if impl_config.types.b_group_scale == DataType.void
|
||||||
else impl_config.types.b_group_scale
|
else impl_config.types.b_group_scale
|
||||||
|
)
|
||||||
prepack_types.append(
|
prepack_types.append(
|
||||||
PrepackTypeConfig(
|
PrepackTypeConfig(
|
||||||
a=impl_config.types.a,
|
a=impl_config.types.a,
|
||||||
b_num_bits=VLLMDataTypeSize[impl_config.types.b],
|
b_num_bits=VLLMDataTypeSize[impl_config.types.b],
|
||||||
convert=convert_type,
|
convert=convert_type,
|
||||||
accumulator=impl_config.types.accumulator,
|
accumulator=impl_config.types.accumulator,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def prepacked_type_key(prepack_type: PrepackTypeConfig):
|
def prepacked_type_key(prepack_type: PrepackTypeConfig):
|
||||||
# For now, we can just use the first accumulator type seen since
|
# For now, we can just use the first accumulator type seen since
|
||||||
@@ -430,10 +451,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
|||||||
unique_prepack_types.append(prepack_type)
|
unique_prepack_types.append(prepack_type)
|
||||||
prepack_types_seen.add(key)
|
prepack_types_seen.add(key)
|
||||||
|
|
||||||
sources.append((
|
sources.append(
|
||||||
|
(
|
||||||
"machete_prepack",
|
"machete_prepack",
|
||||||
prepack_dispatch_template.render(types=unique_prepack_types, ),
|
prepack_dispatch_template.render(
|
||||||
))
|
types=unique_prepack_types,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Split up impls across files
|
# Split up impls across files
|
||||||
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
||||||
@@ -466,10 +491,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
|||||||
curr_impl_in_file += len(files_impls[-1][-1].schedules)
|
curr_impl_in_file += len(files_impls[-1][-1].schedules)
|
||||||
|
|
||||||
for part, file_impls in enumerate(files_impls):
|
for part, file_impls in enumerate(files_impls):
|
||||||
sources.append((
|
sources.append(
|
||||||
f"machete_mm_impl_part{part+1}",
|
(
|
||||||
|
f"machete_mm_impl_part{part + 1}",
|
||||||
mm_impl_template.render(impl_configs=file_impls),
|
mm_impl_template.render(impl_configs=file_impls),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
@@ -514,8 +541,7 @@ def generate():
|
|||||||
# For now we use the same heuristic for all types
|
# For now we use the same heuristic for all types
|
||||||
# Heuristic is currently tuned for H100s
|
# Heuristic is currently tuned for H100s
|
||||||
default_heuristic = [
|
default_heuristic = [
|
||||||
(cond, ScheduleConfig(*tile_config,
|
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
|
||||||
**sch_common_params)) # type: ignore
|
|
||||||
for cond, tile_config in default_tile_heuristic_config.items()
|
for cond, tile_config in default_tile_heuristic_config.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -541,14 +567,18 @@ def generate():
|
|||||||
a_token_scale=DataType.void,
|
a_token_scale=DataType.void,
|
||||||
out=a,
|
out=a,
|
||||||
accumulator=DataType.f32,
|
accumulator=DataType.f32,
|
||||||
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
)
|
||||||
for a in (DataType.f16, DataType.bf16))
|
for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||||
|
for a in (DataType.f16, DataType.bf16)
|
||||||
|
)
|
||||||
|
|
||||||
impl_configs += [
|
impl_configs += [
|
||||||
ImplConfig(x[0], x[1], x[2])
|
ImplConfig(x[0], x[1], x[2])
|
||||||
for x in zip(GPTQ_kernel_type_configs,
|
for x in zip(
|
||||||
|
GPTQ_kernel_type_configs,
|
||||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||||
itertools.repeat(default_heuristic))
|
itertools.repeat(default_heuristic),
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
AWQ_kernel_type_configs = list(
|
AWQ_kernel_type_configs = list(
|
||||||
@@ -561,14 +591,18 @@ def generate():
|
|||||||
a_token_scale=DataType.void,
|
a_token_scale=DataType.void,
|
||||||
out=a,
|
out=a,
|
||||||
accumulator=DataType.f32,
|
accumulator=DataType.f32,
|
||||||
) for b in (DataType.u4, DataType.u8)
|
)
|
||||||
for a in (DataType.f16, DataType.bf16))
|
for b in (DataType.u4, DataType.u8)
|
||||||
|
for a in (DataType.f16, DataType.bf16)
|
||||||
|
)
|
||||||
|
|
||||||
impl_configs += [
|
impl_configs += [
|
||||||
ImplConfig(x[0], x[1], x[2])
|
ImplConfig(x[0], x[1], x[2])
|
||||||
for x in zip(AWQ_kernel_type_configs,
|
for x in zip(
|
||||||
|
AWQ_kernel_type_configs,
|
||||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||||
itertools.repeat(default_heuristic))
|
itertools.repeat(default_heuristic),
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: Support W4A8 when ready
|
# TODO: Support W4A8 when ready
|
||||||
|
|||||||
@@ -33,8 +33,11 @@ def auto_mock(module, attr, max_mocks=50):
|
|||||||
try:
|
try:
|
||||||
# First treat attr as an attr, then as a submodule
|
# First treat attr as an attr, then as a submodule
|
||||||
with patch("importlib.metadata.version", return_value="0.0.0"):
|
with patch("importlib.metadata.version", return_value="0.0.0"):
|
||||||
return getattr(importlib.import_module(module), attr,
|
return getattr(
|
||||||
importlib.import_module(f"{module}.{attr}"))
|
importlib.import_module(module),
|
||||||
|
attr,
|
||||||
|
importlib.import_module(f"{module}.{attr}"),
|
||||||
|
)
|
||||||
except importlib.metadata.PackageNotFoundError as e:
|
except importlib.metadata.PackageNotFoundError as e:
|
||||||
raise e
|
raise e
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
@@ -42,7 +45,8 @@ def auto_mock(module, attr, max_mocks=50):
|
|||||||
sys.modules[e.name] = PydanticMagicMock()
|
sys.modules[e.name] = PydanticMagicMock()
|
||||||
|
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Failed to import {module}.{attr} after mocking {max_mocks} imports")
|
f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
latency = auto_mock("vllm.benchmarks", "latency")
|
latency = auto_mock("vllm.benchmarks", "latency")
|
||||||
@@ -61,9 +65,7 @@ class MarkdownFormatter(HelpFormatter):
|
|||||||
"""Custom formatter that generates markdown for argument groups."""
|
"""Custom formatter that generates markdown for argument groups."""
|
||||||
|
|
||||||
def __init__(self, prog, starting_heading_level=3):
|
def __init__(self, prog, starting_heading_level=3):
|
||||||
super().__init__(prog,
|
super().__init__(prog, max_help_position=float("inf"), width=float("inf"))
|
||||||
max_help_position=float('inf'),
|
|
||||||
width=float('inf'))
|
|
||||||
self._section_heading_prefix = "#" * starting_heading_level
|
self._section_heading_prefix = "#" * starting_heading_level
|
||||||
self._argument_heading_prefix = "#" * (starting_heading_level + 1)
|
self._argument_heading_prefix = "#" * (starting_heading_level + 1)
|
||||||
self._markdown_output = []
|
self._markdown_output = []
|
||||||
@@ -85,23 +87,19 @@ class MarkdownFormatter(HelpFormatter):
|
|||||||
|
|
||||||
def add_arguments(self, actions):
|
def add_arguments(self, actions):
|
||||||
for action in actions:
|
for action in actions:
|
||||||
if (len(action.option_strings) == 0
|
if len(action.option_strings) == 0 or "--help" in action.option_strings:
|
||||||
or "--help" in action.option_strings):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
option_strings = f'`{"`, `".join(action.option_strings)}`'
|
option_strings = f"`{'`, `'.join(action.option_strings)}`"
|
||||||
heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n"
|
heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n"
|
||||||
self._markdown_output.append(heading_md)
|
self._markdown_output.append(heading_md)
|
||||||
|
|
||||||
if choices := action.choices:
|
if choices := action.choices:
|
||||||
choices = f'`{"`, `".join(str(c) for c in choices)}`'
|
choices = f"`{'`, `'.join(str(c) for c in choices)}`"
|
||||||
self._markdown_output.append(
|
self._markdown_output.append(f"Possible choices: {choices}\n\n")
|
||||||
f"Possible choices: {choices}\n\n")
|
elif (metavar := action.metavar) and isinstance(metavar, (list, tuple)):
|
||||||
elif ((metavar := action.metavar)
|
metavar = f"`{'`, `'.join(str(m) for m in metavar)}`"
|
||||||
and isinstance(metavar, (list, tuple))):
|
self._markdown_output.append(f"Possible choices: {metavar}\n\n")
|
||||||
metavar = f'`{"`, `".join(str(m) for m in metavar)}`'
|
|
||||||
self._markdown_output.append(
|
|
||||||
f"Possible choices: {metavar}\n\n")
|
|
||||||
|
|
||||||
if action.help:
|
if action.help:
|
||||||
self._markdown_output.append(f"{action.help}\n\n")
|
self._markdown_output.append(f"{action.help}\n\n")
|
||||||
@@ -143,24 +141,17 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
|||||||
|
|
||||||
# Create parsers to document
|
# Create parsers to document
|
||||||
parsers = {
|
parsers = {
|
||||||
"engine_args":
|
"engine_args": create_parser(EngineArgs.add_cli_args),
|
||||||
create_parser(EngineArgs.add_cli_args),
|
"async_engine_args": create_parser(
|
||||||
"async_engine_args":
|
AsyncEngineArgs.add_cli_args, async_args_only=True
|
||||||
create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True),
|
),
|
||||||
"serve":
|
"serve": create_parser(cli_args.make_arg_parser),
|
||||||
create_parser(cli_args.make_arg_parser),
|
"chat": create_parser(ChatCommand.add_cli_args),
|
||||||
"chat":
|
"complete": create_parser(CompleteCommand.add_cli_args),
|
||||||
create_parser(ChatCommand.add_cli_args),
|
"bench_latency": create_parser(latency.add_cli_args),
|
||||||
"complete":
|
"bench_throughput": create_parser(throughput.add_cli_args),
|
||||||
create_parser(CompleteCommand.add_cli_args),
|
"bench_serve": create_parser(serve.add_cli_args),
|
||||||
"bench_latency":
|
"run-batch": create_parser(run_batch.make_arg_parser),
|
||||||
create_parser(latency.add_cli_args),
|
|
||||||
"bench_throughput":
|
|
||||||
create_parser(throughput.add_cli_args),
|
|
||||||
"bench_serve":
|
|
||||||
create_parser(serve.add_cli_args),
|
|
||||||
"run-batch":
|
|
||||||
create_parser(run_batch.make_arg_parser),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate documentation for each parser
|
# Generate documentation for each parser
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import regex as re
|
|||||||
logger = logging.getLogger("mkdocs")
|
logger = logging.getLogger("mkdocs")
|
||||||
|
|
||||||
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||||
ROOT_DIR_RELATIVE = '../../../../..'
|
ROOT_DIR_RELATIVE = "../../../../.."
|
||||||
EXAMPLE_DIR = ROOT_DIR / "examples"
|
EXAMPLE_DIR = ROOT_DIR / "examples"
|
||||||
EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples"
|
EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples"
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ def fix_case(text: str) -> str:
|
|||||||
r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16
|
r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16
|
||||||
}
|
}
|
||||||
for pattern, repl in subs.items():
|
for pattern, repl in subs.items():
|
||||||
text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE)
|
text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
@@ -59,6 +59,7 @@ class Example:
|
|||||||
determine_title() -> str: Determines the title of the document.
|
determine_title() -> str: Determines the title of the document.
|
||||||
generate() -> str: Generates the documentation content.
|
generate() -> str: Generates the documentation content.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
path: Path
|
path: Path
|
||||||
category: str = None
|
category: str = None
|
||||||
main_file: Path = field(init=False)
|
main_file: Path = field(init=False)
|
||||||
@@ -85,8 +86,7 @@ class Example:
|
|||||||
Raises:
|
Raises:
|
||||||
IndexError: If no Markdown files are found in the directory.
|
IndexError: If no Markdown files are found in the directory.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
return self.path if self.path.is_file() else list(
|
return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop()
|
||||||
self.path.glob("*.md")).pop()
|
|
||||||
|
|
||||||
def determine_other_files(self) -> list[Path]:
|
def determine_other_files(self) -> list[Path]:
|
||||||
"""
|
"""
|
||||||
@@ -109,9 +109,9 @@ class Example:
|
|||||||
# Specify encoding for building on Windows
|
# Specify encoding for building on Windows
|
||||||
with open(self.main_file, encoding="utf-8") as f:
|
with open(self.main_file, encoding="utf-8") as f:
|
||||||
first_line = f.readline().strip()
|
first_line = f.readline().strip()
|
||||||
match = re.match(r'^#\s+(?P<title>.+)$', first_line)
|
match = re.match(r"^#\s+(?P<title>.+)$", first_line)
|
||||||
if match:
|
if match:
|
||||||
return match.group('title')
|
return match.group("title")
|
||||||
return fix_case(self.path.stem.replace("_", " ").title())
|
return fix_case(self.path.stem.replace("_", " ").title())
|
||||||
|
|
||||||
def fix_relative_links(self, content: str) -> str:
|
def fix_relative_links(self, content: str) -> str:
|
||||||
@@ -127,7 +127,7 @@ class Example:
|
|||||||
"""
|
"""
|
||||||
# Regex to match markdown links [text](relative_path)
|
# Regex to match markdown links [text](relative_path)
|
||||||
# This matches links that don't start with http, https, ftp, or #
|
# This matches links that don't start with http, https, ftp, or #
|
||||||
link_pattern = r'\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)'
|
link_pattern = r"\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)"
|
||||||
|
|
||||||
def replace_link(match):
|
def replace_link(match):
|
||||||
link_text = match.group(1)
|
link_text = match.group(1)
|
||||||
@@ -137,7 +137,7 @@ class Example:
|
|||||||
gh_file = (self.main_file.parent / relative_path).resolve()
|
gh_file = (self.main_file.parent / relative_path).resolve()
|
||||||
gh_file = gh_file.relative_to(ROOT_DIR)
|
gh_file = gh_file.relative_to(ROOT_DIR)
|
||||||
|
|
||||||
return f'[{link_text}](gh-file:{gh_file})'
|
return f"[{link_text}](gh-file:{gh_file})"
|
||||||
|
|
||||||
return re.sub(link_pattern, replace_link, content)
|
return re.sub(link_pattern, replace_link, content)
|
||||||
|
|
||||||
@@ -150,9 +150,11 @@ class Example:
|
|||||||
code_fence = "``````"
|
code_fence = "``````"
|
||||||
|
|
||||||
if self.is_code:
|
if self.is_code:
|
||||||
content += (f"{code_fence}{self.main_file.suffix[1:]}\n"
|
content += (
|
||||||
|
f"{code_fence}{self.main_file.suffix[1:]}\n"
|
||||||
f'--8<-- "{self.main_file}"\n'
|
f'--8<-- "{self.main_file}"\n'
|
||||||
f"{code_fence}\n")
|
f"{code_fence}\n"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
with open(self.main_file) as f:
|
with open(self.main_file) as f:
|
||||||
# Skip the title from md snippets as it's been included above
|
# Skip the title from md snippets as it's been included above
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Literal
|
|||||||
|
|
||||||
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||||
# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa
|
# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa
|
||||||
if os.getenv('READTHEDOCS_VERSION_TYPE') == "tag":
|
if os.getenv("READTHEDOCS_VERSION_TYPE") == "tag":
|
||||||
# remove the warning banner if the version is a tagged release
|
# remove the warning banner if the version is a tagged release
|
||||||
mkdocs_dir = Path(__file__).parent.parent
|
mkdocs_dir = Path(__file__).parent.parent
|
||||||
announcement_path = mkdocs_dir / "overrides/main.html"
|
announcement_path = mkdocs_dir / "overrides/main.html"
|
||||||
|
|||||||
@@ -25,8 +25,9 @@ from mkdocs.structure.files import Files
|
|||||||
from mkdocs.structure.pages import Page
|
from mkdocs.structure.pages import Page
|
||||||
|
|
||||||
|
|
||||||
def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
|
def on_page_markdown(
|
||||||
files: Files) -> str:
|
markdown: str, *, page: Page, config: MkDocsConfig, files: Files
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Custom MkDocs plugin hook to rewrite special GitHub reference links
|
Custom MkDocs plugin hook to rewrite special GitHub reference links
|
||||||
in Markdown.
|
in Markdown.
|
||||||
@@ -92,11 +93,11 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
|
|||||||
Example:
|
Example:
|
||||||
[My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123)
|
[My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123)
|
||||||
"""
|
"""
|
||||||
url = f'{urls[match.group("type")]}/{match.group("path")}'
|
url = f"{urls[match.group('type')]}/{match.group('path')}"
|
||||||
if fragment := match.group("fragment"):
|
if fragment := match.group("fragment"):
|
||||||
url += f"#{fragment}"
|
url += f"#{fragment}"
|
||||||
|
|
||||||
return f'[{gh_icon} {match.group("title")}]({url})'
|
return f"[{gh_icon} {match.group('title')}]({url})"
|
||||||
|
|
||||||
def replace_auto_link(match: re.Match) -> str:
|
def replace_auto_link(match: re.Match) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
exclude = [
|
|
||||||
# External file, leaving license intact
|
|
||||||
"examples/other/fp8/quantizer/quantize.py",
|
|
||||||
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
|
||||||
known-first-party = ["vllm"]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
127
pyproject.toml
127
pyproject.toml
@@ -52,27 +52,106 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi
|
|||||||
where = ["."]
|
where = ["."]
|
||||||
include = ["vllm*"]
|
include = ["vllm*"]
|
||||||
|
|
||||||
[tool.yapfignore]
|
|
||||||
ignore_patterns = [
|
|
||||||
".buildkite/**",
|
|
||||||
"benchmarks/**",
|
|
||||||
"build/**",
|
|
||||||
"examples/**",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
# Allow lines to be as long as 80.
|
|
||||||
line-length = 80
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"vllm/third_party/**" = ["ALL"]
|
"vllm/third_party/**" = ["ALL"]
|
||||||
"vllm/version.py" = ["F401"]
|
"vllm/version.py" = ["F401"]
|
||||||
"vllm/_version.py" = ["ALL"]
|
"vllm/_version.py" = ["ALL"]
|
||||||
# Python 3.8 typing - skip V0 code
|
# TEMPORARY! These ignores will be fixed forward
|
||||||
"vllm/attention/**/*.py" = ["UP006", "UP035"]
|
## Line length violations
|
||||||
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
"csrc/cutlass_extensions/vllm_cutlass_library_extension.py" = ["E501"]
|
||||||
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
"tests/compile/piecewise/test_simple.py" = ["E501"]
|
||||||
"vllm/worker/**/*.py" = ["UP006", "UP035"]
|
"tests/compile/piecewise/test_toy_llama.py" = ["E501", "B023"]
|
||||||
|
"tests/entrypoints/conftest.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_audio.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_chat.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_chat_template.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_chat_with_tool_reasoning.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_completion_with_function_calling.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_video.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_vision.py" = ["E501"]
|
||||||
|
"tests/entrypoints/test_chat_utils.py" = ["E501"]
|
||||||
|
"tests/kernels/moe/modular_kernel_tools/common.py" = ["E501"]
|
||||||
|
"tests/models/language/generation/test_gemma.py" = ["E501"]
|
||||||
|
"tests/models/language/generation/test_mistral.py" = ["E501"]
|
||||||
|
"tests/models/multimodal/generation/test_ultravox.py" = ["E501"]
|
||||||
|
"tests/models/multimodal/generation/test_voxtral.py" = ["E501"]
|
||||||
|
"tests/models/multimodal/generation/vlm_utils/custom_inputs.py" = ["E501"]
|
||||||
|
"tests/tool_use/test_tool_choice_required.py" = ["E501"]
|
||||||
|
"tests/v1/attention/utils.py" = ["E501"]
|
||||||
|
"tests/v1/entrypoints/openai/responses/test_image.py" = ["E501"]
|
||||||
|
"tests/v1/kv_connector/nixl_integration/test_accuracy.py" = ["E501"]
|
||||||
|
"tests/v1/kv_connector/unit/test_offloading_connector.py" = ["E501"]
|
||||||
|
"tests/v1/logits_processors/test_custom_offline.py" = ["E501"]
|
||||||
|
"vllm/attention/ops/pallas_kv_cache_update.py" = ["E501"]
|
||||||
|
"vllm/compilation/collective_fusion.py" = ["E501"]
|
||||||
|
"vllm/compilation/wrapper.py" = ["E501"]
|
||||||
|
"vllm/config/vllm.py" = ["E501"]
|
||||||
|
"vllm/distributed/device_communicators/all2all.py" = ["E501"]
|
||||||
|
"vllm/entrypoints/openai/protocol.py" = ["E501"]
|
||||||
|
"vllm/lora/layers/vocal_parallel_embedding.py" = ["E501"]
|
||||||
|
"vllm/model_executor/model_loader/bitsandbytes_loader.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/bailing_moe.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/hyperclovax_vision.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/llama4_eagle.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/longcat_flash_mtp.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/phi4mm.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/qwen3_next.py" = ["E501"]
|
||||||
|
"vllm/model_executor/layers/quantization/ptpc_fp8.py" = ["E501"]
|
||||||
|
"vllm/v1/attention/backends/mla/common.py" = ["E501"]
|
||||||
|
"vllm/v1/engine/utils.py" = ["E501"]
|
||||||
|
"vllm/v1/utils.py" = ["E501"]
|
||||||
|
"vllm/v1/worker/gpu_model_runner.py" = ["E501"]
|
||||||
|
## Simplification rules
|
||||||
|
"tests/distributed/test_expert_placement.py" = ["SIM108"]
|
||||||
|
"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"]
|
||||||
|
"tests/kernels/attention/test_flashmla.py" = ["SIM108"]
|
||||||
|
"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"]
|
||||||
|
"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"]
|
||||||
|
"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"]
|
||||||
|
"tests/kernels/test_onednn.py" = ["SIM108"]
|
||||||
|
"tests/kernels/utils.py" = ["SIM108"]
|
||||||
|
"tests/multimodal/test_processing.py" = ["SIM108"]
|
||||||
|
"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"]
|
||||||
|
"vllm/distributed/parallel_state.py" = ["SIM108"]
|
||||||
|
"vllm/entrypoints/chat_utils.py" = ["SIM108"]
|
||||||
|
"vllm/entrypoints/llm.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/layernorm.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"]
|
||||||
|
"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"]
|
||||||
|
"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"]
|
||||||
|
"vllm/utils/__init__.py" = ["SIM108"]
|
||||||
|
"vllm/v1/sample/ops/bad_words.py" = ["SIM108"]
|
||||||
|
"vllm/v1/sample/rejection_sampler.py" = ["SIM108"]
|
||||||
|
"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"]
|
||||||
|
"vllm/_custom_ops.py" = ["SIM108"]
|
||||||
|
"tools/profiler/print_layerwise_table.py" = ["SIM118"]
|
||||||
|
## Loop variable binding issues
|
||||||
|
"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
|
||||||
|
## Type annotation modernization and other rules
|
||||||
|
"vllm/attention/backends/abstract.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/layer.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/engine/arg_utils.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/engine/metrics.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/engine/metrics_types.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/executor_base.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"]
|
||||||
|
"vllm/executor/ray_utils.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"]
|
||||||
|
## Type comparison issues
|
||||||
|
"vllm/multimodal/inputs.py" = ["E721"]
|
||||||
|
# End of temporary ignores
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
@@ -87,7 +166,7 @@ select = [
|
|||||||
# flake8-simplify
|
# flake8-simplify
|
||||||
"SIM",
|
"SIM",
|
||||||
# isort
|
# isort
|
||||||
# "I",
|
"I",
|
||||||
# flake8-logging-format
|
# flake8-logging-format
|
||||||
"G",
|
"G",
|
||||||
]
|
]
|
||||||
@@ -104,21 +183,15 @@ ignore = [
|
|||||||
"UP007",
|
"UP007",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
docstring-code-format = true
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
plugins = ['pydantic.mypy']
|
plugins = ['pydantic.mypy']
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
follow_imports = "silent"
|
follow_imports = "silent"
|
||||||
|
|
||||||
[tool.isort]
|
|
||||||
skip_glob = [
|
|
||||||
".buildkite/*",
|
|
||||||
"benchmarks/*",
|
|
||||||
"examples/*",
|
|
||||||
]
|
|
||||||
use_parentheses = true
|
|
||||||
skip_gitignore = true
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
markers = [
|
markers = [
|
||||||
"slow_test",
|
"slow_test",
|
||||||
|
|||||||
247
setup.py
247
setup.py
@@ -34,32 +34,36 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# cannot import envs directly because it depends on vllm,
|
# cannot import envs directly because it depends on vllm,
|
||||||
# which is not installed yet
|
# which is not installed yet
|
||||||
envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py'))
|
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py"))
|
||||||
|
|
||||||
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
|
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
|
||||||
|
|
||||||
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
|
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
|
||||||
logger.warning(
|
logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
|
||||||
"VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
|
|
||||||
VLLM_TARGET_DEVICE = "cpu"
|
VLLM_TARGET_DEVICE = "cpu"
|
||||||
elif not (sys.platform.startswith("linux")
|
elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")):
|
||||||
or sys.platform.startswith("darwin")):
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"vLLM only supports Linux platform (including WSL) and MacOS."
|
"vLLM only supports Linux platform (including WSL) and MacOS."
|
||||||
"Building on %s, "
|
"Building on %s, "
|
||||||
"so vLLM may not be able to run correctly", sys.platform)
|
"so vLLM may not be able to run correctly",
|
||||||
|
sys.platform,
|
||||||
|
)
|
||||||
VLLM_TARGET_DEVICE = "empty"
|
VLLM_TARGET_DEVICE = "empty"
|
||||||
elif (sys.platform.startswith("linux") and torch.version.cuda is None
|
elif (
|
||||||
|
sys.platform.startswith("linux")
|
||||||
|
and torch.version.cuda is None
|
||||||
and os.getenv("VLLM_TARGET_DEVICE") is None
|
and os.getenv("VLLM_TARGET_DEVICE") is None
|
||||||
and torch.version.hip is None):
|
and torch.version.hip is None
|
||||||
|
):
|
||||||
# if cuda or hip is not available and VLLM_TARGET_DEVICE is not set,
|
# if cuda or hip is not available and VLLM_TARGET_DEVICE is not set,
|
||||||
# fallback to cpu
|
# fallback to cpu
|
||||||
VLLM_TARGET_DEVICE = "cpu"
|
VLLM_TARGET_DEVICE = "cpu"
|
||||||
|
|
||||||
|
|
||||||
def is_sccache_available() -> bool:
|
def is_sccache_available() -> bool:
|
||||||
return which("sccache") is not None and \
|
return which("sccache") is not None and not bool(
|
||||||
not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0")))
|
int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_ccache_available() -> bool:
|
def is_ccache_available() -> bool:
|
||||||
@@ -83,8 +87,7 @@ def is_url_available(url: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
class CMakeExtension(Extension):
|
class CMakeExtension(Extension):
|
||||||
|
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
|
||||||
def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
|
|
||||||
super().__init__(name, sources=[], py_limited_api=True, **kwa)
|
super().__init__(name, sources=[], py_limited_api=True, **kwa)
|
||||||
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
|
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
|
||||||
|
|
||||||
@@ -121,8 +124,8 @@ class cmake_build_ext(build_ext):
|
|||||||
if nvcc_threads is not None:
|
if nvcc_threads is not None:
|
||||||
nvcc_threads = int(nvcc_threads)
|
nvcc_threads = int(nvcc_threads)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using NVCC_THREADS=%d as the number of nvcc threads.",
|
"Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads
|
||||||
nvcc_threads)
|
)
|
||||||
else:
|
else:
|
||||||
nvcc_threads = 1
|
nvcc_threads = 1
|
||||||
num_jobs = max(1, num_jobs // nvcc_threads)
|
num_jobs = max(1, num_jobs // nvcc_threads)
|
||||||
@@ -146,36 +149,36 @@ class cmake_build_ext(build_ext):
|
|||||||
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
|
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
|
||||||
|
|
||||||
cmake_args = [
|
cmake_args = [
|
||||||
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
|
"-DCMAKE_BUILD_TYPE={}".format(cfg),
|
||||||
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
|
"-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE),
|
||||||
]
|
]
|
||||||
|
|
||||||
verbose = envs.VERBOSE
|
verbose = envs.VERBOSE
|
||||||
if verbose:
|
if verbose:
|
||||||
cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON']
|
cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
|
||||||
|
|
||||||
if is_sccache_available():
|
if is_sccache_available():
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
'-DCMAKE_C_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_C_COMPILER_LAUNCHER=sccache",
|
||||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
|
||||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
|
||||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
|
||||||
]
|
]
|
||||||
elif is_ccache_available():
|
elif is_ccache_available():
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
'-DCMAKE_C_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_C_COMPILER_LAUNCHER=ccache",
|
||||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
|
||||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
|
||||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Pass the python executable to cmake so it can find an exact
|
# Pass the python executable to cmake so it can find an exact
|
||||||
# match.
|
# match.
|
||||||
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
|
cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)]
|
||||||
|
|
||||||
# Pass the python path to cmake so it can reuse the build dependencies
|
# Pass the python path to cmake so it can reuse the build dependencies
|
||||||
# on subsequent calls to python.
|
# on subsequent calls to python.
|
||||||
cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))]
|
cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))]
|
||||||
|
|
||||||
# Override the base directory for FetchContent downloads to $ROOT/.deps
|
# Override the base directory for FetchContent downloads to $ROOT/.deps
|
||||||
# This allows sharing dependencies between profiles,
|
# This allows sharing dependencies between profiles,
|
||||||
@@ -183,7 +186,7 @@ class cmake_build_ext(build_ext):
|
|||||||
# To override this, set the FETCHCONTENT_BASE_DIR environment variable.
|
# To override this, set the FETCHCONTENT_BASE_DIR environment variable.
|
||||||
fc_base_dir = os.path.join(ROOT_DIR, ".deps")
|
fc_base_dir = os.path.join(ROOT_DIR, ".deps")
|
||||||
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
|
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
|
||||||
cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)]
|
cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]
|
||||||
|
|
||||||
#
|
#
|
||||||
# Setup parallelism and build tool
|
# Setup parallelism and build tool
|
||||||
@@ -191,35 +194,36 @@ class cmake_build_ext(build_ext):
|
|||||||
num_jobs, nvcc_threads = self.compute_num_jobs()
|
num_jobs, nvcc_threads = self.compute_num_jobs()
|
||||||
|
|
||||||
if nvcc_threads:
|
if nvcc_threads:
|
||||||
cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)]
|
cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]
|
||||||
|
|
||||||
if is_ninja_available():
|
if is_ninja_available():
|
||||||
build_tool = ['-G', 'Ninja']
|
build_tool = ["-G", "Ninja"]
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
'-DCMAKE_JOB_POOL_COMPILE:STRING=compile',
|
"-DCMAKE_JOB_POOL_COMPILE:STRING=compile",
|
||||||
'-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs),
|
"-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs),
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# Default build tool to whatever cmake picks.
|
# Default build tool to whatever cmake picks.
|
||||||
build_tool = []
|
build_tool = []
|
||||||
# Make sure we use the nvcc from CUDA_HOME
|
# Make sure we use the nvcc from CUDA_HOME
|
||||||
if _is_cuda():
|
if _is_cuda():
|
||||||
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
|
cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
|
||||||
|
|
||||||
other_cmake_args = os.environ.get("CMAKE_ARGS")
|
other_cmake_args = os.environ.get("CMAKE_ARGS")
|
||||||
if other_cmake_args:
|
if other_cmake_args:
|
||||||
cmake_args += other_cmake_args.split()
|
cmake_args += other_cmake_args.split()
|
||||||
|
|
||||||
subprocess.check_call(
|
subprocess.check_call(
|
||||||
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
|
["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
|
||||||
cwd=self.build_temp)
|
cwd=self.build_temp,
|
||||||
|
)
|
||||||
|
|
||||||
def build_extensions(self) -> None:
|
def build_extensions(self) -> None:
|
||||||
# Ensure that CMake is present and working
|
# Ensure that CMake is present and working
|
||||||
try:
|
try:
|
||||||
subprocess.check_output(['cmake', '--version'])
|
subprocess.check_output(["cmake", "--version"])
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise RuntimeError('Cannot find CMake executable') from e
|
raise RuntimeError("Cannot find CMake executable") from e
|
||||||
|
|
||||||
# Create build directory if it does not exist.
|
# Create build directory if it does not exist.
|
||||||
if not os.path.exists(self.build_temp):
|
if not os.path.exists(self.build_temp):
|
||||||
@@ -258,13 +262,18 @@ class cmake_build_ext(build_ext):
|
|||||||
# CMake appends the extension prefix to the install path,
|
# CMake appends the extension prefix to the install path,
|
||||||
# and outdir already contains that prefix, so we need to remove it.
|
# and outdir already contains that prefix, so we need to remove it.
|
||||||
prefix = outdir
|
prefix = outdir
|
||||||
for _ in range(ext.name.count('.')):
|
for _ in range(ext.name.count(".")):
|
||||||
prefix = prefix.parent
|
prefix = prefix.parent
|
||||||
|
|
||||||
# prefix here should actually be the same for all components
|
# prefix here should actually be the same for all components
|
||||||
install_args = [
|
install_args = [
|
||||||
"cmake", "--install", ".", "--prefix", prefix, "--component",
|
"cmake",
|
||||||
target_name(ext.name)
|
"--install",
|
||||||
|
".",
|
||||||
|
"--prefix",
|
||||||
|
prefix,
|
||||||
|
"--component",
|
||||||
|
target_name(ext.name),
|
||||||
]
|
]
|
||||||
subprocess.check_call(install_args, cwd=self.build_temp)
|
subprocess.check_call(install_args, cwd=self.build_temp)
|
||||||
|
|
||||||
@@ -275,12 +284,15 @@ class cmake_build_ext(build_ext):
|
|||||||
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
|
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
|
||||||
# directory so that they can be included in the editable build
|
# directory so that they can be included in the editable build
|
||||||
import glob
|
import glob
|
||||||
files = glob.glob(os.path.join(self.build_lib, "vllm",
|
|
||||||
"vllm_flash_attn", "**", "*.py"),
|
files = glob.glob(
|
||||||
recursive=True)
|
os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"),
|
||||||
|
recursive=True,
|
||||||
|
)
|
||||||
for file in files:
|
for file in files:
|
||||||
dst_file = os.path.join("vllm/vllm_flash_attn",
|
dst_file = os.path.join(
|
||||||
file.split("vllm/vllm_flash_attn/")[-1])
|
"vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1]
|
||||||
|
)
|
||||||
print(f"Copying {file} to {dst_file}")
|
print(f"Copying {file} to {dst_file}")
|
||||||
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
||||||
self.copy_file(file, dst_file)
|
self.copy_file(file, dst_file)
|
||||||
@@ -290,8 +302,7 @@ class precompiled_build_ext(build_ext):
|
|||||||
"""Disables extension building when using precompiled binaries."""
|
"""Disables extension building when using precompiled binaries."""
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
assert _is_cuda(
|
assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
||||||
), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
|
||||||
|
|
||||||
def build_extensions(self) -> None:
|
def build_extensions(self) -> None:
|
||||||
print("Skipping build_ext: using precompiled extensions.")
|
print("Skipping build_ext: using precompiled extensions.")
|
||||||
@@ -312,9 +323,9 @@ class precompiled_wheel_utils:
|
|||||||
wheel_filename = wheel_url_or_path.split("/")[-1]
|
wheel_filename = wheel_url_or_path.split("/")[-1]
|
||||||
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
||||||
wheel_path = os.path.join(temp_dir, wheel_filename)
|
wheel_path = os.path.join(temp_dir, wheel_filename)
|
||||||
print(f"Downloading wheel from {wheel_url_or_path} "
|
print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}")
|
||||||
f"to {wheel_path}")
|
|
||||||
from urllib.request import urlretrieve
|
from urllib.request import urlretrieve
|
||||||
|
|
||||||
urlretrieve(wheel_url_or_path, filename=wheel_path)
|
urlretrieve(wheel_url_or_path, filename=wheel_path)
|
||||||
else:
|
else:
|
||||||
wheel_path = wheel_url_or_path
|
wheel_path = wheel_url_or_path
|
||||||
@@ -335,25 +346,29 @@ class precompiled_wheel_utils:
|
|||||||
]
|
]
|
||||||
|
|
||||||
compiled_regex = re.compile(
|
compiled_regex = re.compile(
|
||||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
||||||
|
)
|
||||||
file_members = list(
|
file_members = list(
|
||||||
filter(lambda x: x.filename in files_to_copy,
|
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
|
||||||
wheel.filelist))
|
)
|
||||||
file_members += list(
|
file_members += list(
|
||||||
filter(lambda x: compiled_regex.match(x.filename),
|
filter(lambda x: compiled_regex.match(x.filename), wheel.filelist)
|
||||||
wheel.filelist))
|
)
|
||||||
|
|
||||||
for file in file_members:
|
for file in file_members:
|
||||||
print(f"[extract] {file.filename}")
|
print(f"[extract] {file.filename}")
|
||||||
target_path = os.path.join(".", file.filename)
|
target_path = os.path.join(".", file.filename)
|
||||||
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
||||||
with wheel.open(file.filename) as src, open(
|
with (
|
||||||
target_path, "wb") as dst:
|
wheel.open(file.filename) as src,
|
||||||
|
open(target_path, "wb") as dst,
|
||||||
|
):
|
||||||
shutil.copyfileobj(src, dst)
|
shutil.copyfileobj(src, dst)
|
||||||
|
|
||||||
pkg = os.path.dirname(file.filename).replace("/", ".")
|
pkg = os.path.dirname(file.filename).replace("/", ".")
|
||||||
package_data_patch.setdefault(pkg, []).append(
|
package_data_patch.setdefault(pkg, []).append(
|
||||||
os.path.basename(file.filename))
|
os.path.basename(file.filename)
|
||||||
|
)
|
||||||
|
|
||||||
return package_data_patch
|
return package_data_patch
|
||||||
finally:
|
finally:
|
||||||
@@ -369,10 +384,13 @@ class precompiled_wheel_utils:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the latest commit hash of the upstream main branch.
|
# Get the latest commit hash of the upstream main branch.
|
||||||
resp_json = subprocess.check_output([
|
resp_json = subprocess.check_output(
|
||||||
"curl", "-s",
|
[
|
||||||
"https://api.github.com/repos/vllm-project/vllm/commits/main"
|
"curl",
|
||||||
]).decode("utf-8")
|
"-s",
|
||||||
|
"https://api.github.com/repos/vllm-project/vllm/commits/main",
|
||||||
|
]
|
||||||
|
).decode("utf-8")
|
||||||
upstream_main_commit = json.loads(resp_json)["sha"]
|
upstream_main_commit = json.loads(resp_json)["sha"]
|
||||||
|
|
||||||
# In Docker build context, .git may be immutable or missing.
|
# In Docker build context, .git may be immutable or missing.
|
||||||
@@ -382,25 +400,32 @@ class precompiled_wheel_utils:
|
|||||||
# Check if the upstream_main_commit exists in the local repo
|
# Check if the upstream_main_commit exists in the local repo
|
||||||
try:
|
try:
|
||||||
subprocess.check_output(
|
subprocess.check_output(
|
||||||
["git", "cat-file", "-e", f"{upstream_main_commit}"])
|
["git", "cat-file", "-e", f"{upstream_main_commit}"]
|
||||||
|
)
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
# If not present, fetch it from the remote repository.
|
# If not present, fetch it from the remote repository.
|
||||||
# Note that this does not update any local branches,
|
# Note that this does not update any local branches,
|
||||||
# but ensures that this commit ref and its history are
|
# but ensures that this commit ref and its history are
|
||||||
# available in our local repo.
|
# available in our local repo.
|
||||||
subprocess.check_call([
|
subprocess.check_call(
|
||||||
"git", "fetch", "https://github.com/vllm-project/vllm",
|
["git", "fetch", "https://github.com/vllm-project/vllm", "main"]
|
||||||
"main"
|
)
|
||||||
])
|
|
||||||
|
|
||||||
# Then get the commit hash of the current branch that is the same as
|
# Then get the commit hash of the current branch that is the same as
|
||||||
# the upstream main commit.
|
# the upstream main commit.
|
||||||
current_branch = subprocess.check_output(
|
current_branch = (
|
||||||
["git", "branch", "--show-current"]).decode("utf-8").strip()
|
subprocess.check_output(["git", "branch", "--show-current"])
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
base_commit = subprocess.check_output([
|
base_commit = (
|
||||||
"git", "merge-base", f"{upstream_main_commit}", current_branch
|
subprocess.check_output(
|
||||||
]).decode("utf-8").strip()
|
["git", "merge-base", f"{upstream_main_commit}", current_branch]
|
||||||
|
)
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
return base_commit
|
return base_commit
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise ValueError(err) from None
|
raise ValueError(err) from None
|
||||||
@@ -408,7 +433,9 @@ class precompiled_wheel_utils:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to get the base commit in the main branch. "
|
"Failed to get the base commit in the main branch. "
|
||||||
"Using the nightly wheel. The libraries in this "
|
"Using the nightly wheel. The libraries in this "
|
||||||
"wheel may not be compatible with your dev branch: %s", err)
|
"wheel may not be compatible with your dev branch: %s",
|
||||||
|
err,
|
||||||
|
)
|
||||||
return "nightly"
|
return "nightly"
|
||||||
|
|
||||||
|
|
||||||
@@ -418,12 +445,13 @@ def _no_device() -> bool:
|
|||||||
|
|
||||||
def _is_cuda() -> bool:
|
def _is_cuda() -> bool:
|
||||||
has_cuda = torch.version.cuda is not None
|
has_cuda = torch.version.cuda is not None
|
||||||
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu())
|
return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()
|
||||||
|
|
||||||
|
|
||||||
def _is_hip() -> bool:
|
def _is_hip() -> bool:
|
||||||
return (VLLM_TARGET_DEVICE == "cuda"
|
return (
|
||||||
or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None
|
VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm"
|
||||||
|
) and torch.version.hip is not None
|
||||||
|
|
||||||
|
|
||||||
def _is_tpu() -> bool:
|
def _is_tpu() -> bool:
|
||||||
@@ -462,8 +490,12 @@ def get_rocm_version():
|
|||||||
minor = ctypes.c_uint32()
|
minor = ctypes.c_uint32()
|
||||||
patch = ctypes.c_uint32()
|
patch = ctypes.c_uint32()
|
||||||
|
|
||||||
if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor),
|
if (
|
||||||
ctypes.byref(patch)) == 0):
|
get_rocm_core_version(
|
||||||
|
ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)
|
||||||
|
)
|
||||||
|
== 0
|
||||||
|
):
|
||||||
return f"{major.value}.{minor.value}.{patch.value}"
|
return f"{major.value}.{minor.value}.{patch.value}"
|
||||||
return None
|
return None
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -476,8 +508,9 @@ def get_nvcc_cuda_version() -> Version:
|
|||||||
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
||||||
"""
|
"""
|
||||||
assert CUDA_HOME is not None, "CUDA_HOME is not set"
|
assert CUDA_HOME is not None, "CUDA_HOME is not set"
|
||||||
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
|
nvcc_output = subprocess.check_output(
|
||||||
universal_newlines=True)
|
[CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True
|
||||||
|
)
|
||||||
output = nvcc_output.split()
|
output = nvcc_output.split()
|
||||||
release_idx = output.index("release") + 1
|
release_idx = output.index("release") + 1
|
||||||
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
||||||
@@ -489,14 +522,20 @@ def get_gaudi_sw_version():
|
|||||||
Returns the driver version.
|
Returns the driver version.
|
||||||
"""
|
"""
|
||||||
# Enable console printing for `hl-smi` check
|
# Enable console printing for `hl-smi` check
|
||||||
output = subprocess.run("hl-smi",
|
output = subprocess.run(
|
||||||
|
"hl-smi",
|
||||||
shell=True,
|
shell=True,
|
||||||
text=True,
|
text=True,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
env={"ENABLE_CONSOLE": "true"})
|
env={"ENABLE_CONSOLE": "true"},
|
||||||
|
)
|
||||||
if output.returncode == 0 and output.stdout:
|
if output.returncode == 0 and output.stdout:
|
||||||
return output.stdout.split("\n")[2].replace(
|
return (
|
||||||
" ", "").split(":")[1][:-1].split("-")[0]
|
output.stdout.split("\n")[2]
|
||||||
|
.replace(" ", "")
|
||||||
|
.split(":")[1][:-1]
|
||||||
|
.split("-")[0]
|
||||||
|
)
|
||||||
return "0.0.0" # when hl-smi is not available
|
return "0.0.0" # when hl-smi is not available
|
||||||
|
|
||||||
|
|
||||||
@@ -546,8 +585,11 @@ def get_requirements() -> list[str]:
|
|||||||
for line in requirements:
|
for line in requirements:
|
||||||
if line.startswith("-r "):
|
if line.startswith("-r "):
|
||||||
resolved_requirements += _read_requirements(line.split()[1])
|
resolved_requirements += _read_requirements(line.split()[1])
|
||||||
elif not line.startswith("--") and not line.startswith(
|
elif (
|
||||||
"#") and line.strip() != "":
|
not line.startswith("--")
|
||||||
|
and not line.startswith("#")
|
||||||
|
and line.strip() != ""
|
||||||
|
):
|
||||||
resolved_requirements.append(line)
|
resolved_requirements.append(line)
|
||||||
return resolved_requirements
|
return resolved_requirements
|
||||||
|
|
||||||
@@ -558,7 +600,7 @@ def get_requirements() -> list[str]:
|
|||||||
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
||||||
modified_requirements = []
|
modified_requirements = []
|
||||||
for req in requirements:
|
for req in requirements:
|
||||||
if ("vllm-flash-attn" in req and cuda_major != "12"):
|
if "vllm-flash-attn" in req and cuda_major != "12":
|
||||||
# vllm-flash-attn is built only for CUDA 12.x.
|
# vllm-flash-attn is built only for CUDA 12.x.
|
||||||
# Skip for other versions.
|
# Skip for other versions.
|
||||||
continue
|
continue
|
||||||
@@ -573,8 +615,7 @@ def get_requirements() -> list[str]:
|
|||||||
elif _is_xpu():
|
elif _is_xpu():
|
||||||
requirements = _read_requirements("xpu.txt")
|
requirements = _read_requirements("xpu.txt")
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.")
|
||||||
"Unsupported platform, please use CUDA, ROCm, or CPU.")
|
|
||||||
return requirements
|
return requirements
|
||||||
|
|
||||||
|
|
||||||
@@ -590,14 +631,13 @@ if _is_cuda():
|
|||||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
||||||
# FA3 requires CUDA 12.3 or later
|
# FA3 requires CUDA 12.3 or later
|
||||||
ext_modules.append(
|
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
|
||||||
# Optional since this doesn't get built (produce an .so file) when
|
# Optional since this doesn't get built (produce an .so file) when
|
||||||
# not targeting a hopper system
|
# not targeting a hopper system
|
||||||
|
ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
CMakeExtension(name="vllm._flashmla_C", optional=True))
|
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
||||||
ext_modules.append(
|
)
|
||||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
|
|
||||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||||
|
|
||||||
if _build_custom_ops():
|
if _build_custom_ops():
|
||||||
@@ -619,6 +659,7 @@ if envs.VLLM_USE_PRECOMPILED:
|
|||||||
wheel_url = wheel_location
|
wheel_url = wheel_location
|
||||||
else:
|
else:
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
arch = platform.machine()
|
arch = platform.machine()
|
||||||
if arch == "x86_64":
|
if arch == "x86_64":
|
||||||
wheel_tag = "manylinux1_x86_64"
|
wheel_tag = "manylinux1_x86_64"
|
||||||
@@ -628,8 +669,11 @@ if envs.VLLM_USE_PRECOMPILED:
|
|||||||
raise ValueError(f"Unsupported architecture: {arch}")
|
raise ValueError(f"Unsupported architecture: {arch}")
|
||||||
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||||
nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
nightly_wheel_url = (
|
||||||
|
f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||||
|
)
|
||||||
from urllib.request import urlopen
|
from urllib.request import urlopen
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with urlopen(wheel_url) as resp:
|
with urlopen(wheel_url) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
@@ -638,8 +682,7 @@ if envs.VLLM_USE_PRECOMPILED:
|
|||||||
print(f"[warn] Falling back to nightly wheel: {e}")
|
print(f"[warn] Falling back to nightly wheel: {e}")
|
||||||
wheel_url = nightly_wheel_url
|
wheel_url = nightly_wheel_url
|
||||||
|
|
||||||
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
|
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url)
|
||||||
wheel_url)
|
|
||||||
for pkg, files in patch.items():
|
for pkg, files in patch.items():
|
||||||
package_data.setdefault(pkg, []).extend(files)
|
package_data.setdefault(pkg, []).extend(files)
|
||||||
|
|
||||||
@@ -650,8 +693,9 @@ if not ext_modules:
|
|||||||
cmdclass = {}
|
cmdclass = {}
|
||||||
else:
|
else:
|
||||||
cmdclass = {
|
cmdclass = {
|
||||||
"build_ext":
|
"build_ext": precompiled_build_ext
|
||||||
precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
|
if envs.VLLM_USE_PRECOMPILED
|
||||||
|
else cmake_build_ext
|
||||||
}
|
}
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
@@ -664,8 +708,11 @@ setup(
|
|||||||
"tensorizer": ["tensorizer==2.10.1"],
|
"tensorizer": ["tensorizer==2.10.1"],
|
||||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||||
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
||||||
"audio": ["librosa", "soundfile",
|
"audio": [
|
||||||
"mistral_common[audio]"], # Required for audio processing
|
"librosa",
|
||||||
|
"soundfile",
|
||||||
|
"mistral_common[audio]",
|
||||||
|
], # Required for audio processing
|
||||||
"video": [], # Kept for backwards compatibility
|
"video": [], # Kept for backwards compatibility
|
||||||
# FlashInfer should be updated together with the Dockerfile
|
# FlashInfer should be updated together with the Dockerfile
|
||||||
"flashinfer": ["flashinfer-python==0.3.1"],
|
"flashinfer": ["flashinfer-python==0.3.1"],
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import weakref
|
import weakref
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
@@ -37,16 +38,21 @@ def test_vllm_gc_ed():
|
|||||||
|
|
||||||
|
|
||||||
def _fix_prompt_embed_outputs(
|
def _fix_prompt_embed_outputs(
|
||||||
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
|
vllm_outputs: list[tuple[list[int], str]],
|
||||||
example_prompts: list[str]) -> list[tuple[list[int], str]]:
|
hf_model: HfRunner,
|
||||||
|
example_prompts: list[str],
|
||||||
|
) -> list[tuple[list[int], str]]:
|
||||||
fixed_vllm_outputs = []
|
fixed_vllm_outputs = []
|
||||||
for vllm_output, hf_input, prompt in zip(
|
for vllm_output, hf_input, prompt in zip(
|
||||||
vllm_outputs, hf_model.get_inputs(example_prompts),
|
vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts
|
||||||
example_prompts):
|
):
|
||||||
hf_input_ids = hf_input["input_ids"].tolist()[0]
|
hf_input_ids = hf_input["input_ids"].tolist()[0]
|
||||||
fixed_vllm_outputs.append(
|
fixed_vllm_outputs.append(
|
||||||
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
|
(
|
||||||
prompt + vllm_output[1]))
|
hf_input_ids + vllm_output[0][len(hf_input_ids) :],
|
||||||
|
prompt + vllm_output[1],
|
||||||
|
)
|
||||||
|
)
|
||||||
return fixed_vllm_outputs
|
return fixed_vllm_outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -69,8 +75,7 @@ def test_models(
|
|||||||
enable_prompt_embeds: bool,
|
enable_prompt_embeds: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
||||||
pytest.skip(
|
pytest.skip(f"{backend} does not support gemma2 with full context length.")
|
||||||
f"{backend} does not support gemma2 with full context length.")
|
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
@@ -78,16 +83,18 @@ def test_models(
|
|||||||
# 5042 tokens for gemma2
|
# 5042 tokens for gemma2
|
||||||
# gemma2 has alternating sliding window size of 4096
|
# gemma2 has alternating sliding window size of 4096
|
||||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||||
prompt = "The following numbers of the sequence " + ", ".join(
|
prompt = (
|
||||||
str(i) for i in range(1024)) + " are:"
|
"The following numbers of the sequence "
|
||||||
|
+ ", ".join(str(i) for i in range(1024))
|
||||||
|
+ " are:"
|
||||||
|
)
|
||||||
example_prompts = [prompt]
|
example_prompts = [prompt]
|
||||||
|
|
||||||
with hf_runner(model) as hf_model:
|
with hf_runner(model) as hf_model:
|
||||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt_embeds = hf_model.get_prompt_embeddings(
|
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||||
example_prompts)
|
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
@@ -99,13 +106,12 @@ def test_models(
|
|||||||
distributed_executor_backend=model_executor,
|
distributed_executor_backend=model_executor,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||||
prompt_embeds, max_tokens)
|
|
||||||
vllm_outputs = _fix_prompt_embed_outputs(
|
vllm_outputs = _fix_prompt_embed_outputs(
|
||||||
vllm_outputs, hf_model, example_prompts)
|
vllm_outputs, hf_model, example_prompts
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
example_prompts, max_tokens)
|
|
||||||
|
|
||||||
check_outputs_equal(
|
check_outputs_equal(
|
||||||
outputs_0_lst=hf_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
@@ -117,21 +123,18 @@ def test_models(
|
|||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, distributed_executor_backend, attention_backend, "
|
"model, distributed_executor_backend, attention_backend, test_suite, extra_env",
|
||||||
"test_suite, extra_env", [
|
[
|
||||||
("distilbert/distilgpt2", "ray", "", "L4", {}),
|
("distilbert/distilgpt2", "ray", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "mp", "", "L4", {}),
|
("distilbert/distilgpt2", "mp", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "ray", "", "L4", {
|
("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||||
"VLLM_SLEEP_WHEN_IDLE": "1"
|
("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||||
}),
|
|
||||||
("distilbert/distilgpt2", "mp", "", "L4", {
|
|
||||||
"VLLM_SLEEP_WHEN_IDLE": "1"
|
|
||||||
}),
|
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
||||||
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||||
def test_models_distributed(
|
def test_models_distributed(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
@@ -149,11 +152,14 @@ def test_models_distributed(
|
|||||||
pytest.skip(f"Skip test for {test_suite}")
|
pytest.skip(f"Skip test for {test_suite}")
|
||||||
|
|
||||||
with monkeypatch.context() as monkeypatch_context:
|
with monkeypatch.context() as monkeypatch_context:
|
||||||
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
|
if (
|
||||||
|
model == "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
and distributed_executor_backend == "ray"
|
||||||
|
and attention_backend == ""
|
||||||
|
and test_suite == "L4"
|
||||||
|
): # noqa
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
pytest.skip(
|
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
|
||||||
"enable_prompt_embeds does not work with ray compiled dag."
|
|
||||||
)
|
|
||||||
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
|
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
|
||||||
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
|
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
|
||||||
|
|
||||||
@@ -185,20 +191,16 @@ def test_models_distributed(
|
|||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
with hf_runner(model, dtype=dtype) as hf_model:
|
with hf_runner(model, dtype=dtype) as hf_model:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt_embeds = hf_model.get_prompt_embeddings(
|
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||||
example_prompts)
|
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
|
||||||
prompt_embeds, max_tokens)
|
|
||||||
vllm_outputs = _fix_prompt_embed_outputs(
|
vllm_outputs = _fix_prompt_embed_outputs(
|
||||||
vllm_outputs, hf_model, example_prompts)
|
vllm_outputs, hf_model, example_prompts
|
||||||
hf_outputs = hf_model.generate_greedy(
|
)
|
||||||
example_prompts, max_tokens)
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
else:
|
else:
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
example_prompts, max_tokens)
|
|
||||||
with hf_runner(model, dtype=dtype) as hf_model:
|
with hf_runner(model, dtype=dtype) as hf_model:
|
||||||
hf_outputs = hf_model.generate_greedy(
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
example_prompts, max_tokens)
|
|
||||||
|
|
||||||
check_outputs_equal(
|
check_outputs_equal(
|
||||||
outputs_0_lst=hf_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
@@ -209,27 +211,23 @@ def test_models_distributed(
|
|||||||
|
|
||||||
|
|
||||||
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
|
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
|
||||||
|
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
|
|
||||||
if not VLLM_USE_V1:
|
if not VLLM_USE_V1:
|
||||||
pytest.skip("Skipping V0 test, dump input not supported")
|
pytest.skip("Skipping V0 test, dump input not supported")
|
||||||
|
|
||||||
# Needed to mock an error in the same process
|
# Needed to mock an error in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
|
with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model:
|
||||||
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
|
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
|
||||||
v1_test_failed_model_execution(vllm_model)
|
v1_test_failed_model_execution(vllm_model)
|
||||||
|
|
||||||
|
|
||||||
def v1_test_failed_model_execution(vllm_model):
|
def v1_test_failed_model_execution(vllm_model):
|
||||||
|
|
||||||
engine = vllm_model.llm.llm_engine
|
engine = vllm_model.llm.llm_engine
|
||||||
mocked_execute_model = Mock(
|
mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
|
||||||
side_effect=RuntimeError("Mocked Critical Error"))
|
engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model
|
||||||
engine.engine_core.engine_core.model_executor.execute_model =\
|
|
||||||
mocked_execute_model
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
prompts = [
|
prompts = [
|
||||||
|
|||||||
@@ -5,5 +5,6 @@ from ..utils import compare_two_settings
|
|||||||
|
|
||||||
|
|
||||||
def test_cpu_offload():
|
def test_cpu_offload():
|
||||||
compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
|
compare_two_settings(
|
||||||
["--cpu-offload-gb", "1"])
|
"meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"]
|
||||||
|
)
|
||||||
|
|||||||
@@ -23,13 +23,13 @@ def test_python_error():
|
|||||||
tensors = []
|
tensors = []
|
||||||
with allocator.use_memory_pool():
|
with allocator.use_memory_pool():
|
||||||
# allocate 70% of the total memory
|
# allocate 70% of the total memory
|
||||||
x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
|
x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||||
tensors.append(x)
|
tensors.append(x)
|
||||||
# release the memory
|
# release the memory
|
||||||
allocator.sleep()
|
allocator.sleep()
|
||||||
|
|
||||||
# allocate more memory than the total memory
|
# allocate more memory than the total memory
|
||||||
y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
|
y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||||
tensors.append(y)
|
tensors.append(y)
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
# when the allocator is woken up, it should raise an error
|
# when the allocator is woken up, it should raise an error
|
||||||
@@ -41,17 +41,17 @@ def test_python_error():
|
|||||||
def test_basic_cumem():
|
def test_basic_cumem():
|
||||||
# some tensors from default memory pool
|
# some tensors from default memory pool
|
||||||
shape = (1024, 1024)
|
shape = (1024, 1024)
|
||||||
x = torch.empty(shape, device='cuda')
|
x = torch.empty(shape, device="cuda")
|
||||||
x.zero_()
|
x.zero_()
|
||||||
|
|
||||||
# some tensors from custom memory pool
|
# some tensors from custom memory pool
|
||||||
allocator = CuMemAllocator.get_instance()
|
allocator = CuMemAllocator.get_instance()
|
||||||
with allocator.use_memory_pool():
|
with allocator.use_memory_pool():
|
||||||
# custom memory pool
|
# custom memory pool
|
||||||
y = torch.empty(shape, device='cuda')
|
y = torch.empty(shape, device="cuda")
|
||||||
y.zero_()
|
y.zero_()
|
||||||
y += 1
|
y += 1
|
||||||
z = torch.empty(shape, device='cuda')
|
z = torch.empty(shape, device="cuda")
|
||||||
z.zero_()
|
z.zero_()
|
||||||
z += 2
|
z += 2
|
||||||
|
|
||||||
@@ -74,16 +74,16 @@ def test_basic_cumem():
|
|||||||
def test_cumem_with_cudagraph():
|
def test_cumem_with_cudagraph():
|
||||||
allocator = CuMemAllocator.get_instance()
|
allocator = CuMemAllocator.get_instance()
|
||||||
with allocator.use_memory_pool():
|
with allocator.use_memory_pool():
|
||||||
weight = torch.eye(1024, device='cuda')
|
weight = torch.eye(1024, device="cuda")
|
||||||
with allocator.use_memory_pool(tag="discard"):
|
with allocator.use_memory_pool(tag="discard"):
|
||||||
cache = torch.empty(1024, 1024, device='cuda')
|
cache = torch.empty(1024, 1024, device="cuda")
|
||||||
|
|
||||||
def model(x):
|
def model(x):
|
||||||
out = x @ weight
|
out = x @ weight
|
||||||
cache[:out.size(0)].copy_(out)
|
cache[: out.size(0)].copy_(out)
|
||||||
return out + 1
|
return out + 1
|
||||||
|
|
||||||
x = torch.empty(128, 1024, device='cuda')
|
x = torch.empty(128, 1024, device="cuda")
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
model(x)
|
model(x)
|
||||||
@@ -109,7 +109,7 @@ def test_cumem_with_cudagraph():
|
|||||||
model_graph.replay()
|
model_graph.replay()
|
||||||
|
|
||||||
# cache content is as expected
|
# cache content is as expected
|
||||||
assert torch.allclose(x, cache[:x.size(0)])
|
assert torch.allclose(x, cache[: x.size(0)])
|
||||||
|
|
||||||
# output content is as expected
|
# output content is as expected
|
||||||
assert torch.allclose(y, x + 1)
|
assert torch.allclose(y, x + 1)
|
||||||
@@ -123,7 +123,8 @@ def test_cumem_with_cudagraph():
|
|||||||
("meta-llama/Llama-3.2-1B", True),
|
("meta-llama/Llama-3.2-1B", True),
|
||||||
# sleep mode with pytorch checkpoint
|
# sleep mode with pytorch checkpoint
|
||||||
("facebook/opt-125m", True),
|
("facebook/opt-125m", True),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
|
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
assert use_v1
|
assert use_v1
|
||||||
|
|||||||
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
|||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_latency():
|
def test_bench_latency():
|
||||||
command = [
|
command = [
|
||||||
"vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32",
|
"vllm",
|
||||||
"--output-len", "1", "--enforce-eager", "--load-format", "dummy"
|
"bench",
|
||||||
|
"latency",
|
||||||
|
"--model",
|
||||||
|
MODEL_NAME,
|
||||||
|
"--input-len",
|
||||||
|
"32",
|
||||||
|
"--output-len",
|
||||||
|
"1",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--load-format",
|
||||||
|
"dummy",
|
||||||
]
|
]
|
||||||
result = subprocess.run(command, capture_output=True, text=True)
|
result = subprocess.run(command, capture_output=True, text=True)
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
|
|||||||
@@ -7,8 +7,11 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
|
from vllm.benchmarks.datasets import (
|
||||||
SampleRequest)
|
RandomDataset,
|
||||||
|
RandomMultiModalDataset,
|
||||||
|
SampleRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@@ -27,11 +30,9 @@ class Params(NamedTuple):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def random_dataset_params() -> Params:
|
def random_dataset_params() -> Params:
|
||||||
return Params(num_requests=16,
|
return Params(
|
||||||
prefix_len=7,
|
num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20
|
||||||
range_ratio=0.3,
|
)
|
||||||
input_len=50,
|
|
||||||
output_len=20)
|
|
||||||
|
|
||||||
|
|
||||||
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
||||||
@@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
|||||||
return (req.prompt, req.prompt_len, req.expected_output_len)
|
return (req.prompt, req.prompt_len, req.expected_output_len)
|
||||||
|
|
||||||
|
|
||||||
def _collect_samples(dataset: RandomDataset,
|
def _collect_samples(
|
||||||
|
dataset: RandomDataset,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
num_requests: int = 16,
|
num_requests: int = 16,
|
||||||
prefix_len: int = 7,
|
prefix_len: int = 7,
|
||||||
range_ratio: float = 0.3,
|
range_ratio: float = 0.3,
|
||||||
input_len: int = 50,
|
input_len: int = 50,
|
||||||
output_len: int = 20) -> list[tuple[str, int, int]]:
|
output_len: int = 20,
|
||||||
|
) -> list[tuple[str, int, int]]:
|
||||||
samples = dataset.sample(
|
samples = dataset.sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=num_requests,
|
num_requests=num_requests,
|
||||||
@@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
|
|||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_dataset_same_seed(
|
def test_random_dataset_same_seed(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
|
||||||
random_dataset_params: Params) -> None:
|
) -> None:
|
||||||
"""Same seed should yield identical outputs, even if global RNGs change.
|
"""Same seed should yield identical outputs, even if global RNGs change.
|
||||||
|
|
||||||
This guards against accidental reliance on Python's random or np.random
|
This guards against accidental reliance on Python's random or np.random
|
||||||
@@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
|
|||||||
common_seed = 123
|
common_seed = 123
|
||||||
dataset_a = RandomDataset(random_seed=common_seed)
|
dataset_a = RandomDataset(random_seed=common_seed)
|
||||||
dataset_b = RandomDataset(random_seed=common_seed)
|
dataset_b = RandomDataset(random_seed=common_seed)
|
||||||
a = _collect_samples(dataset_a,
|
a = _collect_samples(
|
||||||
|
dataset_a,
|
||||||
hf_tokenizer,
|
hf_tokenizer,
|
||||||
num_requests=p.num_requests,
|
num_requests=p.num_requests,
|
||||||
prefix_len=p.prefix_len,
|
prefix_len=p.prefix_len,
|
||||||
range_ratio=p.range_ratio,
|
range_ratio=p.range_ratio,
|
||||||
input_len=p.input_len,
|
input_len=p.input_len,
|
||||||
output_len=p.output_len)
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
|
|
||||||
# Perturb global RNG state to ensure isolation
|
# Perturb global RNG state to ensure isolation
|
||||||
random.seed(999)
|
random.seed(999)
|
||||||
@@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
|
|||||||
np.random.seed(888)
|
np.random.seed(888)
|
||||||
_ = [np.random.random() for _ in range(100)]
|
_ = [np.random.random() for _ in range(100)]
|
||||||
|
|
||||||
b = _collect_samples(dataset_b,
|
b = _collect_samples(
|
||||||
|
dataset_b,
|
||||||
hf_tokenizer,
|
hf_tokenizer,
|
||||||
num_requests=p.num_requests,
|
num_requests=p.num_requests,
|
||||||
prefix_len=p.prefix_len,
|
prefix_len=p.prefix_len,
|
||||||
range_ratio=p.range_ratio,
|
range_ratio=p.range_ratio,
|
||||||
input_len=p.input_len,
|
input_len=p.input_len,
|
||||||
output_len=p.output_len)
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
assert a == b
|
assert a == b
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_dataset_different_seeds(
|
def test_random_dataset_different_seeds(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
|
||||||
random_dataset_params: Params) -> None:
|
) -> None:
|
||||||
"""Different seeds should change outputs with overwhelming likelihood."""
|
"""Different seeds should change outputs with overwhelming likelihood."""
|
||||||
p = random_dataset_params
|
p = random_dataset_params
|
||||||
seed_a = 0
|
seed_a = 0
|
||||||
dataset_a = RandomDataset(random_seed=seed_a)
|
dataset_a = RandomDataset(random_seed=seed_a)
|
||||||
a = _collect_samples(dataset_a,
|
a = _collect_samples(
|
||||||
|
dataset_a,
|
||||||
hf_tokenizer,
|
hf_tokenizer,
|
||||||
num_requests=p.num_requests,
|
num_requests=p.num_requests,
|
||||||
prefix_len=p.prefix_len,
|
prefix_len=p.prefix_len,
|
||||||
range_ratio=p.range_ratio,
|
range_ratio=p.range_ratio,
|
||||||
input_len=p.input_len,
|
input_len=p.input_len,
|
||||||
output_len=p.output_len)
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
|
|
||||||
seed_b = 999
|
seed_b = 999
|
||||||
dataset_b = RandomDataset(random_seed=seed_b)
|
dataset_b = RandomDataset(random_seed=seed_b)
|
||||||
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
||||||
random.seed(seed_a)
|
random.seed(seed_a)
|
||||||
np.random.seed(seed_a)
|
np.random.seed(seed_a)
|
||||||
b = _collect_samples(dataset_b,
|
b = _collect_samples(
|
||||||
|
dataset_b,
|
||||||
hf_tokenizer,
|
hf_tokenizer,
|
||||||
num_requests=p.num_requests,
|
num_requests=p.num_requests,
|
||||||
prefix_len=p.prefix_len,
|
prefix_len=p.prefix_len,
|
||||||
range_ratio=p.range_ratio,
|
range_ratio=p.range_ratio,
|
||||||
input_len=p.input_len,
|
input_len=p.input_len,
|
||||||
output_len=p.output_len)
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
assert a != b
|
assert a != b
|
||||||
|
|
||||||
|
|
||||||
@@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
|
|||||||
# RandomMultiModalDataset tests
|
# RandomMultiModalDataset tests
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def _mm_fingerprint_sample(
|
def _mm_fingerprint_sample(
|
||||||
req: SampleRequest,
|
req: SampleRequest,
|
||||||
) -> tuple[str, int, int, int, list[str]]:
|
) -> tuple[str, int, int, int, list[str]]:
|
||||||
@@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
|
|||||||
item_prefixes.append(f"video:{url[:22]}")
|
item_prefixes.append(f"video:{url[:22]}")
|
||||||
else:
|
else:
|
||||||
item_prefixes.append("unknown:")
|
item_prefixes.append("unknown:")
|
||||||
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
|
return (
|
||||||
item_prefixes)
|
req.prompt,
|
||||||
|
req.prompt_len,
|
||||||
|
req.expected_output_len,
|
||||||
|
len(items),
|
||||||
|
item_prefixes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _collect_mm_samples(
|
def _collect_mm_samples(
|
||||||
@@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
|
|||||||
fb = [_mm_fingerprint_sample(s) for s in b]
|
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||||
assert fa != fb
|
assert fa != fb
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_mm_respects_limits(
|
def test_random_mm_respects_limits(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
@@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
|||||||
for s in samples:
|
for s in samples:
|
||||||
assert s.multi_modal_data == []
|
assert s.multi_modal_data == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_mm_num_items_per_prompt(
|
def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||||
hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
|
||||||
ds = RandomMultiModalDataset(random_seed=0)
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
# Fixed number of images per prompt
|
# Fixed number of images per prompt
|
||||||
# set num_mm_items_range_ratio to 0.0
|
# set num_mm_items_range_ratio to 0.0
|
||||||
@@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
|
|||||||
def test_random_mm_bucket_config_not_mutated(
|
def test_random_mm_bucket_config_not_mutated(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
ds = RandomMultiModalDataset(random_seed=0)
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
# This bucket config is not normalized to sum to 1
|
# This bucket config is not normalized to sum to 1
|
||||||
# and has more buckets than requested images
|
# and has more buckets than requested images
|
||||||
@@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
|
|||||||
# Ensure the original dict content is unchanged
|
# Ensure the original dict content is unchanged
|
||||||
assert original == snapshot
|
assert original == snapshot
|
||||||
|
|
||||||
|
|
||||||
# Vary number of mm items per prompt
|
# Vary number of mm items per prompt
|
||||||
# set num_mm_items_range_ratio to 0.5
|
# set num_mm_items_range_ratio to 0.5
|
||||||
samples_varying_items = _collect_mm_samples(
|
samples_varying_items = _collect_mm_samples(
|
||||||
|
|||||||
@@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
args = [
|
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
|
||||||
"--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"
|
|
||||||
]
|
|
||||||
|
|
||||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||||
yield remote_server
|
yield remote_server
|
||||||
@@ -46,6 +44,7 @@ def test_bench_serve(server):
|
|||||||
|
|
||||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_serve_chat(server):
|
def test_bench_serve_chat(server):
|
||||||
command = [
|
command = [
|
||||||
|
|||||||
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
|||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_throughput():
|
def test_bench_throughput():
|
||||||
command = [
|
command = [
|
||||||
"vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len",
|
"vllm",
|
||||||
"32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy"
|
"bench",
|
||||||
|
"throughput",
|
||||||
|
"--model",
|
||||||
|
MODEL_NAME,
|
||||||
|
"--input-len",
|
||||||
|
"32",
|
||||||
|
"--output-len",
|
||||||
|
"1",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--load-format",
|
||||||
|
"dummy",
|
||||||
]
|
]
|
||||||
result = subprocess.run(command, capture_output=True, text=True)
|
result = subprocess.run(command, capture_output=True, text=True)
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ class LazyInitPass(InductorPass):
|
|||||||
and then immediately invoke it.
|
and then immediately invoke it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pass_cls: type[VllmInductorPass],
|
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
|
||||||
vllm_config: VllmConfig):
|
|
||||||
self.pass_cls = pass_cls
|
self.pass_cls = pass_cls
|
||||||
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
||||||
|
|
||||||
@@ -45,20 +44,18 @@ class TestBackend:
|
|||||||
Inductor config is default-initialized from VllmConfig.CompilationConfig.
|
Inductor config is default-initialized from VllmConfig.CompilationConfig.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
|
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
|
||||||
None]]):
|
|
||||||
self.custom_passes = list(passes)
|
self.custom_passes = list(passes)
|
||||||
compile_config = get_current_vllm_config().compilation_config
|
compile_config = get_current_vllm_config().compilation_config
|
||||||
self.inductor_config = compile_config.inductor_compile_config
|
self.inductor_config = compile_config.inductor_compile_config
|
||||||
self.inductor_config['force_disable_caches'] = True
|
self.inductor_config["force_disable_caches"] = True
|
||||||
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
|
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
|
||||||
|
|
||||||
def __call__(self, graph: fx.GraphModule, example_inputs):
|
def __call__(self, graph: fx.GraphModule, example_inputs):
|
||||||
self.graph_pre_compile = deepcopy(graph)
|
self.graph_pre_compile = deepcopy(graph)
|
||||||
from torch._inductor.compile_fx import compile_fx
|
from torch._inductor.compile_fx import compile_fx
|
||||||
return compile_fx(graph,
|
|
||||||
example_inputs,
|
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
|
||||||
config_patches=self.inductor_config)
|
|
||||||
|
|
||||||
@with_pattern_match_debug
|
@with_pattern_match_debug
|
||||||
def post_pass(self, graph: fx.Graph):
|
def post_pass(self, graph: fx.Graph):
|
||||||
@@ -82,8 +79,7 @@ class TestBackend:
|
|||||||
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
||||||
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
||||||
if fully_replaced:
|
if fully_replaced:
|
||||||
assert num_post == 0, \
|
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
||||||
f"Unexpected op {op.name()} in post-pass graph"
|
|
||||||
|
|
||||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||||
for op in ops:
|
for op in ops:
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ test_params_full_cudagraph = []
|
|||||||
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
|
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
|
||||||
for mla_backend in MLA_backends:
|
for mla_backend in MLA_backends:
|
||||||
test_params_full_cudagraph.append(
|
test_params_full_cudagraph.append(
|
||||||
pytest.param(
|
pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
|
||||||
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
|
)
|
||||||
|
|
||||||
# Qwen/Qwen2-1.5B-Instruct with other backends
|
# Qwen/Qwen2-1.5B-Instruct with other backends
|
||||||
other_backend_configs = [
|
other_backend_configs = [
|
||||||
@@ -47,7 +47,8 @@ other_backend_configs = [
|
|||||||
]
|
]
|
||||||
for backend_config in other_backend_configs:
|
for backend_config in other_backend_configs:
|
||||||
test_params_full_cudagraph.append(
|
test_params_full_cudagraph.append(
|
||||||
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
|
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class")
|
@pytest.fixture(scope="class")
|
||||||
@@ -55,8 +56,10 @@ def llm_pair(request):
|
|||||||
model, backend_config = request.param
|
model, backend_config = request.param
|
||||||
|
|
||||||
# Dynamically skip test if GPU capability is not met
|
# Dynamically skip test if GPU capability is not met
|
||||||
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
if (
|
||||||
!= current_platform.get_device_capability():
|
backend_config.specific_gpu_arch
|
||||||
|
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
|
||||||
|
):
|
||||||
if backend_config.specific_gpu_arch == (9, 0):
|
if backend_config.specific_gpu_arch == (9, 0):
|
||||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||||
elif backend_config.specific_gpu_arch == (10, 0):
|
elif backend_config.specific_gpu_arch == (10, 0):
|
||||||
@@ -76,8 +79,7 @@ def llm_pair(request):
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
max_num_seqs=128,
|
max_num_seqs=128,
|
||||||
compilation_config=\
|
compilation_config=CompilationConfig(**backend_config.comp_config),
|
||||||
CompilationConfig(**backend_config.comp_config),
|
|
||||||
generation_config="vllm",
|
generation_config="vllm",
|
||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
@@ -113,7 +115,9 @@ class TestFullCUDAGraph:
|
|||||||
meaning there would be multiple LLM instances hogging memory simultaneously.
|
meaning there would be multiple LLM instances hogging memory simultaneously.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
|
@pytest.mark.parametrize(
|
||||||
|
("batch_size", "max_tokens"),
|
||||||
|
[
|
||||||
(1, 10),
|
(1, 10),
|
||||||
(7, 10),
|
(7, 10),
|
||||||
(16, 10),
|
(16, 10),
|
||||||
@@ -124,9 +128,9 @@ class TestFullCUDAGraph:
|
|||||||
(123, 10),
|
(123, 10),
|
||||||
(8, 5),
|
(8, 5),
|
||||||
(8, 30),
|
(8, 30),
|
||||||
])
|
],
|
||||||
def test_full_cudagraph(self, batch_size, max_tokens,
|
)
|
||||||
llm_pair: tuple[LLM, LLM]):
|
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
|
||||||
"""
|
"""
|
||||||
Test various batch sizes and max_tokens to ensure that the
|
Test various batch sizes and max_tokens to ensure that the
|
||||||
full cudagraph compilation works for padded cases too.
|
full cudagraph compilation works for padded cases too.
|
||||||
@@ -137,26 +141,34 @@ class TestFullCUDAGraph:
|
|||||||
prompts = ["the quick brown fox"] * batch_size
|
prompts = ["the quick brown fox"] * batch_size
|
||||||
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
||||||
# that can amplify tiny numeric differences across runtimes.
|
# that can amplify tiny numeric differences across runtimes.
|
||||||
sampling_params = SamplingParams(temperature=0.0,
|
sampling_params = SamplingParams(
|
||||||
max_tokens=max_tokens,
|
temperature=0.0, max_tokens=max_tokens, top_p=1.0
|
||||||
top_p=1.0)
|
)
|
||||||
|
|
||||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
# Check that all responses are the same
|
# Check that all responses are the same
|
||||||
for piecewise_res, full_res in zip(piecewise_responses,
|
for piecewise_res, full_res in zip(piecewise_responses, full_responses):
|
||||||
full_responses):
|
assert (
|
||||||
assert piecewise_res.outputs[0].text.lower() == \
|
piecewise_res.outputs[0].text.lower()
|
||||||
full_res.outputs[0].text.lower()
|
== full_res.outputs[0].text.lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||||
def test_full_cudagraph_with_invalid_backend():
|
def test_full_cudagraph_with_invalid_backend():
|
||||||
with temporary_environ({
|
with (
|
||||||
|
temporary_environ(
|
||||||
|
{
|
||||||
"VLLM_USE_V1": "1",
|
"VLLM_USE_V1": "1",
|
||||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
|
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
|
||||||
# Flex_Attention is not supported with full cuda graph
|
# Flex_Attention is not supported with full cuda graph
|
||||||
}), pytest.raises(RuntimeError):
|
}
|
||||||
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
),
|
||||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
|
pytest.raises(RuntimeError),
|
||||||
|
):
|
||||||
|
LLM(
|
||||||
|
model="Qwen/Qwen2-1.5B-Instruct",
|
||||||
|
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
|
||||||
|
)
|
||||||
|
|||||||
@@ -10,10 +10,14 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.compilation.backends import set_model_tag
|
from vllm.compilation.backends import set_model_tag
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||||
support_torch_compile)
|
from vllm.config import (
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
CompilationConfig,
|
||||||
VllmConfig, set_current_vllm_config)
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
@@ -27,12 +31,7 @@ RANDOM_SEED = 0
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class ParentModel(nn.Module):
|
class ParentModel(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -40,7 +39,6 @@ class ParentModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
|
|
||||||
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
||||||
@@ -51,17 +49,21 @@ class Attention(nn.Module):
|
|||||||
nn.init.xavier_normal_(
|
nn.init.xavier_normal_(
|
||||||
self.pre_attn.weight.data,
|
self.pre_attn.weight.data,
|
||||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
|
)
|
||||||
nn.init.xavier_normal_(
|
nn.init.xavier_normal_(
|
||||||
self.post_attn.weight.data,
|
self.post_attn.weight.data,
|
||||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x_f32 = x.float()
|
x_f32 = x.float()
|
||||||
return (x_f32 * torch.rsqrt(
|
return (
|
||||||
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
|
x_f32
|
||||||
self.rms_norm_weight).to(x.dtype)
|
* torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
|
||||||
|
* self.rms_norm_weight
|
||||||
|
).to(x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = self.pre_attn(x)
|
x = self.pre_attn(x)
|
||||||
@@ -76,14 +78,15 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class CompiledAttention(nn.Module):
|
class CompiledAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
*,
|
*,
|
||||||
mlp_size: int,
|
mlp_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = '',
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attn = Attention(mlp_size, hidden_size)
|
self.attn = Attention(mlp_size, hidden_size)
|
||||||
|
|
||||||
@@ -93,21 +96,21 @@ class CompiledAttention(nn.Module):
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class CompiledAttentionTwo(CompiledAttention):
|
class CompiledAttentionTwo(CompiledAttention):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.attn(x) + x
|
return self.attn(x) + x
|
||||||
|
|
||||||
|
|
||||||
@ignore_torch_compile
|
@ignore_torch_compile
|
||||||
class SimpleModelWithTwoGraphs(ParentModel):
|
class SimpleModelWithTwoGraphs(ParentModel):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
*,
|
*,
|
||||||
mlp_size: int,
|
mlp_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = '',
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
# Test will fail without set_model_tag here with error:
|
# Test will fail without set_model_tag here with error:
|
||||||
# "ValueError: too many values to unpack (expected 3)"
|
# "ValueError: too many values to unpack (expected 3)"
|
||||||
@@ -142,32 +145,45 @@ class SimpleModelWithTwoGraphs(ParentModel):
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
|
def run_model(
|
||||||
cudagraph_runtime_mode: CUDAGraphMode):
|
vllm_config: VllmConfig,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
cudagraph_runtime_mode: CUDAGraphMode,
|
||||||
|
):
|
||||||
with set_forward_context({}, vllm_config=vllm_config):
|
with set_forward_context({}, vllm_config=vllm_config):
|
||||||
# warmup for the model with cudagraph_mode NONE
|
# warmup for the model with cudagraph_mode NONE
|
||||||
model(inputs)
|
model(inputs)
|
||||||
|
|
||||||
# simulate cudagraphs capturing
|
# simulate cudagraphs capturing
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=2, )):
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(inputs[:2])
|
model(inputs[:2])
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=1, )):
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(inputs[:1])
|
model(inputs[:1])
|
||||||
|
|
||||||
# simulate cudagraphs replay
|
# simulate cudagraphs replay
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=2, )):
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(inputs[:2])
|
output = model(inputs[:2])
|
||||||
|
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
@@ -178,19 +194,27 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
|||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
# piecewise compile
|
# piecewise compile
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_cudagraph=True,
|
use_cudagraph=True,
|
||||||
splitting_ops=["silly.attention"],
|
splitting_ops=["silly.attention"],
|
||||||
cudagraph_capture_sizes=[1, 2],
|
cudagraph_capture_sizes=[1, 2],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
model = (
|
||||||
|
SimpleModelWithTwoGraphs(
|
||||||
|
mlp_size=MLP_SIZE,
|
||||||
hidden_size=HIDDEN_SIZE,
|
hidden_size=HIDDEN_SIZE,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
prefix='').eval().cuda()
|
prefix="",
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
# Pre-allocate memory for CUDAGraph which expects
|
# Pre-allocate memory for CUDAGraph which expects
|
||||||
# static tensor addresses
|
# static tensor addresses
|
||||||
@@ -207,19 +231,27 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
|||||||
num_cudagraph_captured=8,
|
num_cudagraph_captured=8,
|
||||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
|
||||||
|
|
||||||
# no compile or cudagraph
|
# no compile or cudagraph
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.NO_COMPILATION, ))
|
compilation_config=CompilationConfig(
|
||||||
|
level=CompilationLevel.NO_COMPILATION,
|
||||||
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
model = (
|
||||||
|
SimpleModelWithTwoGraphs(
|
||||||
|
mlp_size=MLP_SIZE,
|
||||||
hidden_size=HIDDEN_SIZE,
|
hidden_size=HIDDEN_SIZE,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
prefix='').eval().cuda()
|
prefix="",
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=0,
|
num_graphs_seen=0,
|
||||||
@@ -228,22 +260,29 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
|||||||
num_backend_compilations=0,
|
num_backend_compilations=0,
|
||||||
num_cudagraph_captured=0,
|
num_cudagraph_captured=0,
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
|
||||||
|
|
||||||
# piecewise compile without CUDA graph
|
# piecewise compile without CUDA graph
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_cudagraph=False,
|
use_cudagraph=False,
|
||||||
splitting_ops=["silly.attention"],
|
splitting_ops=["silly.attention"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
model = (
|
||||||
|
SimpleModelWithTwoGraphs(
|
||||||
|
mlp_size=MLP_SIZE,
|
||||||
hidden_size=HIDDEN_SIZE,
|
hidden_size=HIDDEN_SIZE,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
prefix='').eval().cuda()
|
prefix="",
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=2,
|
num_graphs_seen=2,
|
||||||
@@ -252,8 +291,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
|||||||
num_backend_compilations=4,
|
num_backend_compilations=4,
|
||||||
num_cudagraph_captured=0, # no cudagraph captured
|
num_cudagraph_captured=0, # no cudagraph captured
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
|
||||||
|
|
||||||
# Generally don't expect outputs with and without inductor
|
# Generally don't expect outputs with and without inductor
|
||||||
# to be bitwise equivalent
|
# to be bitwise equivalent
|
||||||
|
|||||||
@@ -11,8 +11,13 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (
|
||||||
VllmConfig, set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
@@ -23,12 +28,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class SillyModel(nn.Module):
|
class SillyModel(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -60,7 +60,8 @@ def _run_simple_model(
|
|||||||
expected_num_backend_compilations,
|
expected_num_backend_compilations,
|
||||||
expected_num_cudagraph_captured,
|
expected_num_cudagraph_captured,
|
||||||
):
|
):
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_cudagraph=True,
|
use_cudagraph=True,
|
||||||
use_inductor=use_inductor,
|
use_inductor=use_inductor,
|
||||||
@@ -68,21 +69,23 @@ def _run_simple_model(
|
|||||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||||
cudagraph_copy_inputs=True,
|
cudagraph_copy_inputs=True,
|
||||||
cudagraph_capture_sizes=[1, 2],
|
cudagraph_capture_sizes=[1, 2],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SillyModel(vllm_config=vllm_config, prefix='')
|
model = SillyModel(vllm_config=vllm_config, prefix="")
|
||||||
|
|
||||||
inputs = torch.randn(100).cuda()
|
inputs = torch.randn(100).cuda()
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with (
|
||||||
|
compilation_counter.expect(
|
||||||
num_graphs_seen=1, # one graph for the model
|
num_graphs_seen=1, # one graph for the model
|
||||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||||
num_piecewise_capturable_graphs_seen=
|
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||||
expected_num_piecewise_capturable_graphs_seen,
|
|
||||||
num_backend_compilations=expected_num_backend_compilations,
|
num_backend_compilations=expected_num_backend_compilations,
|
||||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||||
), set_forward_context(None,
|
),
|
||||||
vllm_config=vllm_config): # background context
|
set_forward_context(None, vllm_config=vllm_config),
|
||||||
|
): # background context
|
||||||
# warm up with background context
|
# warm up with background context
|
||||||
model(inputs)
|
model(inputs)
|
||||||
|
|
||||||
@@ -91,13 +94,19 @@ def _run_simple_model(
|
|||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(2).cuda())
|
model(torch.randn(2).cuda())
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=1, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(1).cuda())
|
model(torch.randn(1).cuda())
|
||||||
|
|
||||||
input = torch.zeros(2).cuda()
|
input = torch.zeros(2).cuda()
|
||||||
@@ -106,7 +115,10 @@ def _run_simple_model(
|
|||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(input)
|
output = model(input)
|
||||||
assert get_global_counter() == 2
|
assert get_global_counter() == 2
|
||||||
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||||
@@ -122,10 +134,8 @@ def test_simple_piecewise_compile(use_inductor):
|
|||||||
use_inductor=use_inductor,
|
use_inductor=use_inductor,
|
||||||
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
||||||
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
||||||
expected_num_backend_compilations=
|
expected_num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||||
3, # num_piecewise_capturable_graphs_seen
|
expected_num_cudagraph_captured=6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
expected_num_cudagraph_captured=
|
|
||||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -134,8 +144,7 @@ def test_simple_piecewise_compile(use_inductor):
|
|||||||
def test_simple_inductor_graph_partition(splitting_ops):
|
def test_simple_inductor_graph_partition(splitting_ops):
|
||||||
assert VLLM_USE_V1
|
assert VLLM_USE_V1
|
||||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
pytest.skip("inductor graph partition is only available "
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
"in PyTorch 2.9+")
|
|
||||||
|
|
||||||
_run_simple_model(
|
_run_simple_model(
|
||||||
# inductor graph partition automatically resets splitting_ops
|
# inductor graph partition automatically resets splitting_ops
|
||||||
@@ -143,13 +152,9 @@ def test_simple_inductor_graph_partition(splitting_ops):
|
|||||||
splitting_ops=splitting_ops,
|
splitting_ops=splitting_ops,
|
||||||
use_inductor_graph_partition=True,
|
use_inductor_graph_partition=True,
|
||||||
use_inductor=True,
|
use_inductor=True,
|
||||||
expected_num_piecewise_graphs_seen=
|
expected_num_piecewise_graphs_seen=1, # since not splitting at fx graph level
|
||||||
1, # since not splitting at fx graph level
|
expected_num_piecewise_capturable_graphs_seen=1, # since not splitting at fx graph level
|
||||||
expected_num_piecewise_capturable_graphs_seen=
|
expected_num_backend_compilations=1, # since not splitting at fx graph level
|
||||||
1, # since not splitting at fx graph level
|
expected_num_cudagraph_captured=6, # inductor graph partition still captures 6
|
||||||
expected_num_backend_compilations=
|
|
||||||
1, # since not splitting at fx graph level
|
|
||||||
expected_num_cudagraph_captured=
|
|
||||||
6, # inductor graph partition still captures 6
|
|
||||||
# graph, same as fx graph partition.
|
# graph, same as fx graph partition.
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ This is a tractable model, the weights and computation are specially designed
|
|||||||
if the config `tractable_init` is set to True. Otherwise, the weights are
|
if the config `tractable_init` is set to True. Otherwise, the weights are
|
||||||
initialized randomly with a fixed seed.
|
initialized randomly with a fixed seed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@@ -17,8 +18,13 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (
|
||||||
VllmConfig, set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
@@ -43,15 +49,14 @@ class LlamaConfig:
|
|||||||
factors.append((k, v))
|
factors.append((k, v))
|
||||||
factors.sort()
|
factors.sort()
|
||||||
import hashlib
|
import hashlib
|
||||||
return hashlib.md5(str(factors).encode(),
|
|
||||||
usedforsecurity=False).hexdigest()
|
return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert self.mlp_size >= self.hidden_size
|
assert self.mlp_size >= self.hidden_size
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig) -> None:
|
def __init__(self, config: LlamaConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_projection = nn.Linear(
|
self.gate_up_projection = nn.Linear(
|
||||||
@@ -66,31 +71,31 @@ class LlamaMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.tractable_init:
|
if config.tractable_init:
|
||||||
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
|
nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
|
||||||
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
|
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
|
||||||
nn.init.eye_(self.down_projection.weight.data)
|
nn.init.eye_(self.down_projection.weight.data)
|
||||||
else:
|
else:
|
||||||
nn.init.xavier_normal_(self.gate_up_projection.weight.data,
|
nn.init.xavier_normal_(
|
||||||
generator=torch.Generator().manual_seed(
|
self.gate_up_projection.weight.data,
|
||||||
config.random_seed),
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
nn.init.xavier_normal_(self.down_projection.weight.data,
|
)
|
||||||
generator=torch.Generator().manual_seed(
|
nn.init.xavier_normal_(
|
||||||
config.random_seed),
|
self.down_projection.weight.data,
|
||||||
gain=0.001)
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
|
gain=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# for tractable_init and positive input, this is
|
# for tractable_init and positive input, this is
|
||||||
# essentially an elementwise-square
|
# essentially an elementwise-square
|
||||||
x = self.gate_up_projection(x)
|
x = self.gate_up_projection(x)
|
||||||
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
|
x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
|
||||||
x[:, x.size(1) // 2:])
|
|
||||||
x = self.down_projection(x)
|
x = self.down_projection(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LlamaAttention(nn.Module):
|
class LlamaAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig) -> None:
|
def __init__(self, config: LlamaConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.qkv_projection = nn.Linear(
|
self.qkv_projection = nn.Linear(
|
||||||
@@ -106,21 +111,25 @@ class LlamaAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.tractable_init:
|
if config.tractable_init:
|
||||||
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
|
nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
|
||||||
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
|
nn.init.eye_(
|
||||||
config.hidden_size])
|
self.qkv_projection.weight.data[
|
||||||
nn.init.eye_(self.qkv_projection.weight.data[2 *
|
config.hidden_size : 2 * config.hidden_size
|
||||||
config.hidden_size:])
|
]
|
||||||
|
)
|
||||||
|
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
|
||||||
nn.init.eye_(self.output_projection.weight.data)
|
nn.init.eye_(self.output_projection.weight.data)
|
||||||
else:
|
else:
|
||||||
nn.init.xavier_normal_(self.qkv_projection.weight.data,
|
nn.init.xavier_normal_(
|
||||||
generator=torch.Generator().manual_seed(
|
self.qkv_projection.weight.data,
|
||||||
config.random_seed),
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
nn.init.xavier_normal_(self.output_projection.weight.data,
|
)
|
||||||
generator=torch.Generator().manual_seed(
|
nn.init.xavier_normal_(
|
||||||
config.random_seed),
|
self.output_projection.weight.data,
|
||||||
gain=0.001)
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
|
gain=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -144,7 +153,6 @@ class LlamaAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
class LlamaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig) -> None:
|
def __init__(self, config: LlamaConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attention = LlamaAttention(config)
|
self.self_attention = LlamaAttention(config)
|
||||||
@@ -173,8 +181,9 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = hidden_states + 1
|
hidden_states = hidden_states + 1
|
||||||
|
|
||||||
hidden_states = self.self_attention(positions=positions,
|
hidden_states = self.self_attention(
|
||||||
hidden_states=hidden_states)
|
positions=positions, hidden_states=hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -186,20 +195,22 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class LlamaModel(nn.Module):
|
class LlamaModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
*,
|
*,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
prefix: str = '',
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding_tokens = nn.Embedding(
|
self.embedding_tokens = nn.Embedding(
|
||||||
num_embeddings=config.vocab_size,
|
num_embeddings=config.vocab_size,
|
||||||
embedding_dim=config.hidden_size,
|
embedding_dim=config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[LlamaDecoderLayer(config) for _ in range(config.num_layers)])
|
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]
|
||||||
|
)
|
||||||
|
|
||||||
# this is the initial value of the hidden states
|
# this is the initial value of the hidden states
|
||||||
self.embedding_tokens.weight.data.fill_(config.init_value)
|
self.embedding_tokens.weight.data.fill_(config.init_value)
|
||||||
@@ -216,34 +227,39 @@ class LlamaModel(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def tractable_computation(input_ids: torch.Tensor,
|
def tractable_computation(
|
||||||
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
init_value: float = 1.0) -> torch.Tensor:
|
init_value: float = 1.0,
|
||||||
hidden_states = torch.ones(input_ids.size(0),
|
) -> torch.Tensor:
|
||||||
|
hidden_states = (
|
||||||
|
torch.ones(
|
||||||
|
input_ids.size(0),
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
dtype=input_ids.dtype) * init_value
|
dtype=input_ids.dtype,
|
||||||
|
)
|
||||||
|
* init_value
|
||||||
|
)
|
||||||
|
|
||||||
# first layer
|
# first layer
|
||||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||||
hidden_states = (residual + 1)**2
|
hidden_states = (residual + 1) ** 2
|
||||||
|
|
||||||
# following layers
|
# following layers
|
||||||
for _ in range(config.num_layers - 1):
|
for _ in range(config.num_layers - 1):
|
||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||||
hidden_states = (residual + 1)**2
|
hidden_states = (residual + 1) ** 2
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(llama_config,
|
def run_model(
|
||||||
use_compile: bool,
|
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
|
||||||
use_inductor: bool,
|
) -> torch.Tensor:
|
||||||
split_attn: bool = False) -> torch.Tensor:
|
|
||||||
|
|
||||||
if use_compile:
|
if use_compile:
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
@@ -256,54 +272,66 @@ def run_model(llama_config,
|
|||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
else:
|
else:
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
level=CompilationLevel.NO_COMPILATION, )
|
level=CompilationLevel.NO_COMPILATION,
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=compilation_config,
|
vllm_config = VllmConfig(
|
||||||
additional_config=llama_config)
|
compilation_config=compilation_config, additional_config=llama_config
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = LlamaModel(config=llama_config,
|
model = (
|
||||||
vllm_config=vllm_config,
|
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||||
prefix="").eval().cuda()
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
with set_forward_context({},
|
with set_forward_context({}, vllm_config=vllm_config): # background context
|
||||||
vllm_config=vllm_config): # background context
|
|
||||||
B = 16 # max batch size
|
B = 16 # max batch size
|
||||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||||
positions = torch.arange(B).cuda()
|
positions = torch.arange(B).cuda()
|
||||||
|
|
||||||
# warmup for the model with cudagraph_mode NONE
|
# warmup for the model with cudagraph_mode NONE
|
||||||
model(input_ids, positions)
|
model(input_ids, positions)
|
||||||
|
|
||||||
# simulate cudagraphs capturing
|
# simulate cudagraphs capturing
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=2, )):
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(input_ids[:2], positions[:2])
|
model(input_ids[:2], positions[:2])
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=1, )):
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(input_ids[:1], positions[:1])
|
model(input_ids[:1], positions[:1])
|
||||||
|
|
||||||
input_ids[:2].zero_()
|
input_ids[:2].zero_()
|
||||||
# simulate cudagraphs replay
|
# simulate cudagraphs replay
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=2, )):
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(input_ids[:2], positions[:2])
|
output = model(input_ids[:2], positions[:2])
|
||||||
|
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
|
|
||||||
if llama_config.tractable_init:
|
if llama_config.tractable_init:
|
||||||
expected_output = tractable_computation(input_ids[:2],
|
expected_output = tractable_computation(
|
||||||
positions[:2],
|
input_ids[:2], positions[:2], llama_config
|
||||||
llama_config).cpu()
|
).cpu()
|
||||||
|
|
||||||
assert torch.allclose(output, expected_output)
|
assert torch.allclose(output, expected_output)
|
||||||
else:
|
else:
|
||||||
@@ -314,16 +342,13 @@ def run_model(llama_config,
|
|||||||
def test_toy_llama(use_inductor: bool):
|
def test_toy_llama(use_inductor: bool):
|
||||||
# compare output with and without piecewise compilation
|
# compare output with and without piecewise compilation
|
||||||
|
|
||||||
llama_config = LlamaConfig(hidden_size=128,
|
llama_config = LlamaConfig(
|
||||||
mlp_size=256,
|
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
|
||||||
vocab_size=128,
|
)
|
||||||
num_layers=12)
|
|
||||||
|
|
||||||
tractable_config = LlamaConfig(hidden_size=128,
|
tractable_config = LlamaConfig(
|
||||||
mlp_size=256,
|
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
|
||||||
vocab_size=128,
|
)
|
||||||
num_layers=2,
|
|
||||||
tractable_init=True)
|
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
@@ -333,8 +358,7 @@ def test_toy_llama(use_inductor: bool):
|
|||||||
num_backend_compilations=0,
|
num_backend_compilations=0,
|
||||||
num_cudagraph_captured=0,
|
num_cudagraph_captured=0,
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
|
||||||
run_model(llama_config, use_inductor=False, use_compile=False))
|
|
||||||
run_model(tractable_config, use_inductor=False, use_compile=False)
|
run_model(tractable_config, use_inductor=False, use_compile=False)
|
||||||
|
|
||||||
if use_inductor:
|
if use_inductor:
|
||||||
@@ -347,37 +371,37 @@ def test_toy_llama(use_inductor: bool):
|
|||||||
num_piecewise_graphs_seen=1,
|
num_piecewise_graphs_seen=1,
|
||||||
num_piecewise_capturable_graphs_seen=1,
|
num_piecewise_capturable_graphs_seen=1,
|
||||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||||
num_cudagraph_captured=
|
num_cudagraph_captured=2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(
|
||||||
run_model(llama_config,
|
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
|
||||||
use_inductor=use_inductor,
|
)
|
||||||
use_compile=True))
|
|
||||||
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1, # one graph for the model
|
num_graphs_seen=1, # one graph for the model
|
||||||
num_piecewise_graphs_seen=2 * llama_config.num_layers +
|
num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1
|
||||||
1, # 2 * num_layers + 1
|
num_piecewise_capturable_graphs_seen=1
|
||||||
num_piecewise_capturable_graphs_seen=1 +
|
+ llama_config.num_layers, # 1 + num_layers
|
||||||
llama_config.num_layers, # 1 + num_layers
|
num_backend_compilations=1
|
||||||
num_backend_compilations=1 +
|
+ llama_config.num_layers, # num_piecewise_capturable_graphs_seen
|
||||||
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
|
num_cudagraph_captured=2
|
||||||
num_cudagraph_captured=2 *
|
* (
|
||||||
(1 + llama_config.num_layers
|
1 + llama_config.num_layers
|
||||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(
|
||||||
run_model(llama_config,
|
run_model(
|
||||||
|
llama_config,
|
||||||
use_inductor=use_inductor,
|
use_inductor=use_inductor,
|
||||||
use_compile=True,
|
use_compile=True,
|
||||||
split_attn=True))
|
split_attn=True,
|
||||||
run_model(tractable_config,
|
)
|
||||||
use_inductor=use_inductor,
|
)
|
||||||
use_compile=True,
|
run_model(
|
||||||
split_attn=True)
|
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(1, len(outputs)):
|
for i in range(1, len(outputs)):
|
||||||
assert torch.allclose(outputs[0], outputs[i])
|
assert torch.allclose(outputs[0], outputs[i])
|
||||||
@@ -388,17 +412,15 @@ def benchmark():
|
|||||||
from triton.testing import do_bench
|
from triton.testing import do_bench
|
||||||
|
|
||||||
# similar to llama 3.1-8B
|
# similar to llama 3.1-8B
|
||||||
llama_config = LlamaConfig(hidden_size=4096,
|
llama_config = LlamaConfig(
|
||||||
mlp_size=14336,
|
hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
|
||||||
vocab_size=128 * 1024,
|
)
|
||||||
num_layers=32)
|
|
||||||
|
|
||||||
# a tiny model to measure the overhead
|
# a tiny model to measure the overhead
|
||||||
# of piecewise cudagraph
|
# of piecewise cudagraph
|
||||||
llama_config = LlamaConfig(hidden_size=40,
|
llama_config = LlamaConfig(
|
||||||
mlp_size=80,
|
hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
|
||||||
vocab_size=128,
|
)
|
||||||
num_layers=2)
|
|
||||||
|
|
||||||
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
|
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
|
||||||
|
|
||||||
@@ -424,12 +446,15 @@ def benchmark():
|
|||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = LlamaModel(config=llama_config,
|
model = (
|
||||||
vllm_config=vllm_config,
|
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||||
prefix="").eval().cuda().to(torch.bfloat16)
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
.to(torch.bfloat16)
|
||||||
|
)
|
||||||
|
|
||||||
B = 256 # max batch size
|
B = 256 # max batch size
|
||||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||||
positions = torch.arange(B).cuda().to(torch.bfloat16)
|
positions = torch.arange(B).cuda().to(torch.bfloat16)
|
||||||
|
|
||||||
graphs = {}
|
graphs = {}
|
||||||
@@ -451,21 +476,25 @@ def benchmark():
|
|||||||
# and use it later, because it will look up the name `b` in the
|
# and use it later, because it will look up the name `b` in the
|
||||||
# enclosing scope, and the value of `b` will always be 256.
|
# enclosing scope, and the value of `b` will always be 256.
|
||||||
# it is fine here, because we only use the lambda function once.
|
# it is fine here, because we only use the lambda function once.
|
||||||
runtime = do_bench(lambda: graphs[b][0] # noqa
|
runtime = do_bench(
|
||||||
(input_ids[:b], positions[:b])) # noqa
|
lambda: graphs[b][0]( # noqa
|
||||||
|
input_ids[:b], positions[:b]
|
||||||
|
)
|
||||||
|
) # noqa
|
||||||
piecewise_cudagraph_time[b] = runtime
|
piecewise_cudagraph_time[b] = runtime
|
||||||
else:
|
else:
|
||||||
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
|
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
|
||||||
eager_runtime = do_bench(
|
eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
|
||||||
lambda: model(input_ids[:b], positions[:b])) # noqa
|
|
||||||
full_cudagraph_time[b] = runtime
|
full_cudagraph_time[b] = runtime
|
||||||
eager_time[b] = eager_runtime
|
eager_time[b] = eager_runtime
|
||||||
|
|
||||||
# print in tabular format
|
# print in tabular format
|
||||||
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
|
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
|
||||||
for b in cudagraph_sizes:
|
for b in cudagraph_sizes:
|
||||||
print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
|
print(
|
||||||
f"\t{piecewise_cudagraph_time[b]:.3f}")
|
f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
|
||||||
|
f"\t{piecewise_cudagraph_time[b]:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -31,8 +31,9 @@ def reset_global_counter():
|
|||||||
_global_counter = 0
|
_global_counter = 0
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
def silly_attention(
|
||||||
out: torch.Tensor) -> None:
|
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Unified attention implementation that depends on
|
Unified attention implementation that depends on
|
||||||
all inputs and affects the output.
|
all inputs and affects the output.
|
||||||
@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|||||||
out.copy_(q + k + v)
|
out.copy_(q + k + v)
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
def silly_attention_fake(
|
||||||
out: torch.Tensor) -> None:
|
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
|
||||||
|
) -> None:
|
||||||
"""Fake implementation for testing"""
|
"""Fake implementation for testing"""
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -60,5 +62,5 @@ direct_register_custom_op(
|
|||||||
mutates_args=["out"],
|
mutates_args=["out"],
|
||||||
fake_impl=silly_attention_fake,
|
fake_impl=silly_attention_fake,
|
||||||
target_lib=silly_lib,
|
target_lib=silly_lib,
|
||||||
tags=(torch._C.Tag.cudagraph_unsafe, ),
|
tags=(torch._C.Tag.cudagraph_unsafe,),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,18 +8,30 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.collective_fusion import AsyncTPPass
|
from vllm.compilation.collective_fusion import AsyncTPPass
|
||||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (
|
||||||
PassConfig, VllmConfig)
|
CompilationConfig,
|
||||||
from vllm.distributed import (tensor_model_parallel_all_gather,
|
DeviceConfig,
|
||||||
tensor_model_parallel_reduce_scatter)
|
ModelConfig,
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
PassConfig,
|
||||||
initialize_model_parallel)
|
VllmConfig,
|
||||||
|
)
|
||||||
|
from vllm.distributed import (
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_reduce_scatter,
|
||||||
|
)
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
from ..models.registry import HF_EXAMPLE_MODELS
|
from ..models.registry import HF_EXAMPLE_MODELS
|
||||||
from ..utils import (compare_two_settings, create_new_process_for_each_test,
|
from ..utils import (
|
||||||
multi_gpu_test)
|
compare_two_settings,
|
||||||
|
create_new_process_for_each_test,
|
||||||
|
multi_gpu_test,
|
||||||
|
)
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
@@ -33,14 +45,13 @@ prompts = [
|
|||||||
|
|
||||||
|
|
||||||
class TestMMRSModel(torch.nn.Module):
|
class TestMMRSModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
(self.hidden_size * 2, hidden_size)),
|
torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
|
|
||||||
@@ -66,14 +77,13 @@ class TestMMRSModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestAGMMModel(torch.nn.Module):
|
class TestAGMMModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.weight = torch.nn.Parameter(torch.empty(
|
self.weight = torch.nn.Parameter(
|
||||||
(hidden_size, hidden_size)),
|
torch.empty((hidden_size, hidden_size)), requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.weight, std=0.02)
|
torch.nn.init.normal_(self.weight, std=0.02)
|
||||||
|
|
||||||
@@ -96,20 +106,21 @@ class TestAGMMModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class _BaseScaledMMModel(torch.nn.Module):
|
class _BaseScaledMMModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
|
self.weight = (
|
||||||
.contiguous().transpose(0, 1)
|
torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
|
||||||
|
.contiguous()
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize scale_b for _scaled_mm.
|
# Initialize scale_b for _scaled_mm.
|
||||||
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
class TestScaledMMRSModel(_BaseScaledMMModel):
|
class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
||||||
@@ -117,11 +128,13 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
"""
|
"""
|
||||||
fp8_input = input.to(FP8_DTYPE)
|
fp8_input = input.to(FP8_DTYPE)
|
||||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||||
scaled_mm = torch._scaled_mm(fp8_input,
|
scaled_mm = torch._scaled_mm(
|
||||||
|
fp8_input,
|
||||||
self.weight,
|
self.weight,
|
||||||
scale_a=scale_a,
|
scale_a=scale_a,
|
||||||
scale_b=self.scale_b,
|
scale_b=self.scale_b,
|
||||||
out_dtype=self.dtype)
|
out_dtype=self.dtype,
|
||||||
|
)
|
||||||
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
||||||
return reduce_scatter
|
return reduce_scatter
|
||||||
|
|
||||||
@@ -133,7 +146,6 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the all gather + scaled_mm in the FX graph
|
Forward pass implementing the all gather + scaled_mm in the FX graph
|
||||||
@@ -143,11 +155,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
|
|||||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||||
|
|
||||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||||
scaled_mm = torch._scaled_mm(all_gather,
|
scaled_mm = torch._scaled_mm(
|
||||||
|
all_gather,
|
||||||
self.weight,
|
self.weight,
|
||||||
scale_a=scale_a,
|
scale_a=scale_a,
|
||||||
scale_b=self.scale_b,
|
scale_b=self.scale_b,
|
||||||
out_dtype=self.dtype)
|
out_dtype=self.dtype,
|
||||||
|
)
|
||||||
return scaled_mm
|
return scaled_mm
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -158,7 +172,6 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
||||||
@@ -167,11 +180,14 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
"""
|
"""
|
||||||
fp8_input = input.to(FP8_DTYPE)
|
fp8_input = input.to(FP8_DTYPE)
|
||||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||||
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
|
mm_out = torch.empty(
|
||||||
|
(fp8_input.shape[0], self.weight.shape[1]),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=input.device)
|
device=input.device,
|
||||||
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
|
)
|
||||||
self.scale_b, None)
|
torch.ops._C.cutlass_scaled_mm(
|
||||||
|
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
|
||||||
|
)
|
||||||
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
||||||
return reduce_scatter
|
return reduce_scatter
|
||||||
|
|
||||||
@@ -183,7 +199,6 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the all gather + cutlass_scaled_mm
|
Forward pass implementing the all gather + cutlass_scaled_mm
|
||||||
@@ -195,11 +210,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||||
|
|
||||||
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
|
mm_out = torch.empty(
|
||||||
|
(all_gather.shape[0], self.weight.shape[1]),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=all_gather.device)
|
device=all_gather.device,
|
||||||
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
|
)
|
||||||
scale_a, self.scale_b, None)
|
torch.ops._C.cutlass_scaled_mm(
|
||||||
|
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
|
||||||
|
)
|
||||||
return mm_out
|
return mm_out
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -210,23 +228,37 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize("test_model", [
|
@pytest.mark.parametrize(
|
||||||
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
|
"test_model",
|
||||||
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
|
[
|
||||||
])
|
TestMMRSModel,
|
||||||
|
TestAGMMModel,
|
||||||
|
TestScaledMMRSModel,
|
||||||
|
TestAGScaledMMModel,
|
||||||
|
TestCutlassScaledMMRSModel,
|
||||||
|
TestAGCutlassScaledMMModel,
|
||||||
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("seq_len", [16])
|
@pytest.mark.parametrize("seq_len", [16])
|
||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
def test_async_tp_pass_replace(
|
||||||
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
|
||||||
hidden_size: int, dtype: torch.dtype):
|
):
|
||||||
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
|
if (
|
||||||
|
test_model
|
||||||
|
in (
|
||||||
|
TestScaledMMRSModel,
|
||||||
|
TestAGScaledMMModel,
|
||||||
TestCutlassScaledMMRSModel,
|
TestCutlassScaledMMRSModel,
|
||||||
TestAGCutlassScaledMMModel) and dtype == torch.float16:
|
TestAGCutlassScaledMMModel,
|
||||||
|
)
|
||||||
|
and dtype == torch.float16
|
||||||
|
):
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"Only bf16 high precision output types are supported for " \
|
"Only bf16 high precision output types are supported for "
|
||||||
"per-token (row-wise) scaling"
|
"per-token (row-wise) scaling"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -235,19 +267,24 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
|||||||
def run_torch_spawn(fn, nprocs):
|
def run_torch_spawn(fn, nprocs):
|
||||||
# need to use torch.mp.spawn otherwise will have problems with
|
# need to use torch.mp.spawn otherwise will have problems with
|
||||||
# torch.distributed and cuda
|
# torch.distributed and cuda
|
||||||
torch.multiprocessing.spawn(fn,
|
torch.multiprocessing.spawn(
|
||||||
args=(num_processes, test_model,
|
fn,
|
||||||
batch_size, seq_len, hidden_size,
|
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||||
dtype),
|
nprocs=nprocs,
|
||||||
nprocs=nprocs)
|
)
|
||||||
|
|
||||||
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
|
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
|
||||||
|
|
||||||
|
|
||||||
def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
def async_tp_pass_on_test_model(
|
||||||
|
local_rank: int,
|
||||||
|
world_size: int,
|
||||||
test_model_cls: torch.nn.Module,
|
test_model_cls: torch.nn.Module,
|
||||||
batch_size: int, seq_len: int,
|
batch_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype):
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
@@ -255,13 +292,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# initialize distributed
|
# initialize distributed
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -269,27 +308,28 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
|
|
||||||
# configure vllm config for SequenceParallelismPass
|
# configure vllm config for SequenceParallelismPass
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
vllm_config.compilation_config = CompilationConfig(
|
||||||
enable_async_tp=True, ), )
|
pass_config=PassConfig(
|
||||||
|
enable_async_tp=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
# in the vllm_config, it's not really used.
|
# in the vllm_config, it's not really used.
|
||||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||||
vllm_config.model_config = ModelConfig(model=model_name,
|
vllm_config.model_config = ModelConfig(
|
||||||
trust_remote_code=True,
|
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||||
dtype=dtype,
|
)
|
||||||
seed=42)
|
|
||||||
|
|
||||||
async_tp_pass = AsyncTPPass(vllm_config)
|
async_tp_pass = AsyncTPPass(vllm_config)
|
||||||
backend = TestBackend(async_tp_pass)
|
backend = TestBackend(async_tp_pass)
|
||||||
|
|
||||||
model = test_model_cls(hidden_size,
|
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||||
dtype) # Pass dtype to model constructor
|
|
||||||
|
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
hidden_states = torch.randn(
|
||||||
dtype=dtype,
|
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
|
|
||||||
compiled_model = torch.compile(model, backend=backend)
|
compiled_model = torch.compile(model, backend=backend)
|
||||||
compiled_model(hidden_states)
|
compiled_model(hidden_states)
|
||||||
@@ -306,10 +346,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
|
|
||||||
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@pytest.mark.parametrize("model_id", [
|
@pytest.mark.parametrize(
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"model_id",
|
||||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
|
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
|
||||||
])
|
)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||||
@@ -342,12 +382,10 @@ def test_async_tp_pass_correctness(
|
|||||||
common_args.append("--enforce-eager")
|
common_args.append("--enforce-eager")
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
'level': 3,
|
"level": 3,
|
||||||
'compile_sizes': [2, 4, 8],
|
"compile_sizes": [2, 4, 8],
|
||||||
'splitting_ops': [],
|
"splitting_ops": [],
|
||||||
'pass_config': {
|
"pass_config": {"enable_async_tp": async_tp_enabled},
|
||||||
'enable_async_tp': async_tp_enabled
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async_tp_env = tp_env = {
|
async_tp_env = tp_env = {
|
||||||
@@ -372,9 +410,6 @@ def test_async_tp_pass_correctness(
|
|||||||
"mp",
|
"mp",
|
||||||
]
|
]
|
||||||
|
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(
|
||||||
async_tp_args,
|
model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate"
|
||||||
tp_args,
|
)
|
||||||
async_tp_env,
|
|
||||||
tp_env,
|
|
||||||
method="generate")
|
|
||||||
|
|||||||
@@ -103,15 +103,20 @@ def test_compile_correctness(
|
|||||||
attn_backend = test_setting.attn_backend
|
attn_backend = test_setting.attn_backend
|
||||||
method = test_setting.method
|
method = test_setting.method
|
||||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
if cuda_device_count_stateless() < pp_size * tp_size:
|
||||||
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
pytest.skip(
|
||||||
f"{cuda_device_count_stateless()}")
|
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||||
|
f"{cuda_device_count_stateless()}"
|
||||||
|
)
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
final_args = [
|
final_args = [
|
||||||
"--enforce-eager", *model_args, "-pp",
|
"--enforce-eager",
|
||||||
str(pp_size), "-tp",
|
*model_args,
|
||||||
str(tp_size)
|
"-pp",
|
||||||
|
str(pp_size),
|
||||||
|
"-tp",
|
||||||
|
str(tp_size),
|
||||||
]
|
]
|
||||||
|
|
||||||
all_args: list[list[str]] = []
|
all_args: list[list[str]] = []
|
||||||
@@ -130,7 +135,8 @@ def test_compile_correctness(
|
|||||||
model,
|
model,
|
||||||
all_args,
|
all_args,
|
||||||
all_envs,
|
all_envs,
|
||||||
method=method if method != "generate" else "generate_close")
|
method=method if method != "generate" else "generate_close",
|
||||||
|
)
|
||||||
all_envs.clear()
|
all_envs.clear()
|
||||||
all_args.clear()
|
all_args.clear()
|
||||||
|
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ from vllm.utils import _is_torch_equal_or_newer
|
|||||||
|
|
||||||
|
|
||||||
def test_version():
|
def test_version():
|
||||||
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
|
||||||
assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
|
||||||
assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
|
||||||
assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
|
||||||
assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev')
|
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
||||||
|
|
||||||
|
|
||||||
def test_use_cudagraphs_dynamic(monkeypatch):
|
def test_use_cudagraphs_dynamic(monkeypatch):
|
||||||
@@ -21,7 +21,7 @@ def test_use_cudagraphs_dynamic(monkeypatch):
|
|||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
assert vllm_config.compilation_config.use_cudagraph
|
assert vllm_config.compilation_config.use_cudagraph
|
||||||
|
|
||||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
assert not vllm_config.compilation_config.use_cudagraph
|
assert not vllm_config.compilation_config.use_cudagraph
|
||||||
|
|
||||||
@@ -44,19 +44,23 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
|
|||||||
assert vllm.envs.VLLM_USE_V1
|
assert vllm.envs.VLLM_USE_V1
|
||||||
|
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)
|
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
"use_cudagraph": False, # speed things up a bit
|
"use_cudagraph": False, # speed things up a bit
|
||||||
}
|
}
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(num_cache_entries_updated=0,
|
compilation_counter.expect(
|
||||||
num_compiled_artifacts_saved=0),
|
num_cache_entries_updated=0, num_compiled_artifacts_saved=0
|
||||||
|
),
|
||||||
# loading the model causes compilation (if enabled) to happen
|
# loading the model causes compilation (if enabled) to happen
|
||||||
vllm_runner('facebook/opt-125m',
|
vllm_runner(
|
||||||
|
"facebook/opt-125m",
|
||||||
compilation_config=compilation_config,
|
compilation_config=compilation_config,
|
||||||
gpu_memory_utilization=0.4) as _):
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -67,7 +71,7 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
|||||||
assert vllm.envs.VLLM_USE_V1
|
assert vllm.envs.VLLM_USE_V1
|
||||||
|
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
"cudagraph_capture_sizes": [100],
|
"cudagraph_capture_sizes": [100],
|
||||||
@@ -80,9 +84,12 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
|||||||
num_cudagraph_captured=13 if enabled else 0,
|
num_cudagraph_captured=13 if enabled else 0,
|
||||||
),
|
),
|
||||||
# loading the model causes compilation (if enabled) to happen
|
# loading the model causes compilation (if enabled) to happen
|
||||||
vllm_runner('facebook/opt-125m',
|
vllm_runner(
|
||||||
|
"facebook/opt-125m",
|
||||||
compilation_config=compilation_config,
|
compilation_config=compilation_config,
|
||||||
gpu_memory_utilization=0.4) as _):
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -90,14 +97,17 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
|||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_dynamo_as_is(vllm_runner, monkeypatch):
|
def test_dynamo_as_is(vllm_runner, monkeypatch):
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(dynamo_as_is_count=1),
|
compilation_counter.expect(dynamo_as_is_count=1),
|
||||||
# loading the model causes compilation (if enabled) to happen
|
# loading the model causes compilation (if enabled) to happen
|
||||||
vllm_runner('facebook/opt-125m',
|
vllm_runner(
|
||||||
|
"facebook/opt-125m",
|
||||||
compilation_config={"level": 1},
|
compilation_config={"level": 1},
|
||||||
gpu_memory_utilization=0.4) as _):
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -105,14 +115,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
|
|||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_no_compilation(vllm_runner, monkeypatch):
|
def test_no_compilation(vllm_runner, monkeypatch):
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(num_graphs_seen=0,
|
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
|
||||||
dynamo_as_is_count=0),
|
|
||||||
# loading the model causes compilation (if enabled) to happen
|
# loading the model causes compilation (if enabled) to happen
|
||||||
vllm_runner('facebook/opt-125m',
|
vllm_runner(
|
||||||
|
"facebook/opt-125m",
|
||||||
compilation_config={"level": 0},
|
compilation_config={"level": 0},
|
||||||
gpu_memory_utilization=0.4) as _):
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -120,77 +132,73 @@ def test_no_compilation(vllm_runner, monkeypatch):
|
|||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_enforce_eager(vllm_runner, monkeypatch):
|
def test_enforce_eager(vllm_runner, monkeypatch):
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(num_graphs_seen=0,
|
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
|
||||||
dynamo_as_is_count=0),
|
|
||||||
# loading the model causes compilation (if enabled) to happen
|
# loading the model causes compilation (if enabled) to happen
|
||||||
vllm_runner('facebook/opt-125m',
|
vllm_runner(
|
||||||
enforce_eager=True,
|
"facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
|
||||||
gpu_memory_utilization=0.4) as _):
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_splitting_ops_dynamic():
|
def test_splitting_ops_dynamic():
|
||||||
# Default config
|
# Default config
|
||||||
config = VllmConfig()
|
config = VllmConfig()
|
||||||
assert config.compilation_config.cudagraph_mode == \
|
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||||
CUDAGraphMode.FULL_AND_PIECEWISE
|
|
||||||
assert config.compilation_config.splitting_ops_contain_attention()
|
assert config.compilation_config.splitting_ops_contain_attention()
|
||||||
|
|
||||||
# When use_inductor_graph_partition=True
|
# When use_inductor_graph_partition=True
|
||||||
if _is_torch_equal_or_newer('2.9.0.dev'):
|
if _is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
# inductor graph partition is only available in PyTorch 2.9+.
|
# inductor graph partition is only available in PyTorch 2.9+.
|
||||||
# this is a fast config check so we are not using pytest.skip.
|
# this is a fast config check so we are not using pytest.skip.
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
use_inductor_graph_partition=True,
|
compilation_config=CompilationConfig(
|
||||||
splitting_ops=["silly_attention"]))
|
use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
|
||||||
|
)
|
||||||
|
)
|
||||||
# should ignore splitting_ops
|
# should ignore splitting_ops
|
||||||
assert config.compilation_config.splitting_ops == []
|
assert config.compilation_config.splitting_ops == []
|
||||||
|
|
||||||
# When attn_fusion pass enabled.
|
# When attn_fusion pass enabled.
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
pass_config={
|
compilation_config=CompilationConfig(
|
||||||
"enable_attn_fusion": True,
|
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||||
"enable_noop": True
|
|
||||||
},
|
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
assert config.compilation_config.splitting_ops == []
|
assert config.compilation_config.splitting_ops == []
|
||||||
# cudagraph mode also fall back to FULL
|
# cudagraph mode also fall back to FULL
|
||||||
assert config.compilation_config.cudagraph_mode == \
|
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||||
CUDAGraphMode.FULL
|
|
||||||
|
|
||||||
# splitting_ops can not contain attention ops when attn_fusion
|
# splitting_ops can not contain attention ops when attn_fusion
|
||||||
# pass enabled.
|
# pass enabled.
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
pass_config={
|
compilation_config=CompilationConfig(
|
||||||
"enable_attn_fusion": True,
|
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||||
"enable_noop": True
|
|
||||||
},
|
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
# work around for accessing all attntion ops
|
# work around for accessing all attntion ops
|
||||||
splitting_ops=CompilationConfig()._attention_ops,
|
splitting_ops=CompilationConfig()._attention_ops,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# When both use_inductor_graph_partition and attn_fusion pass enabled.
|
# When both use_inductor_graph_partition and attn_fusion pass enabled.
|
||||||
if _is_torch_equal_or_newer('2.9.0.dev'):
|
if _is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
use_inductor_graph_partition=True,
|
use_inductor_graph_partition=True,
|
||||||
pass_config={
|
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||||
"enable_attn_fusion": True,
|
|
||||||
"enable_noop": True
|
|
||||||
},
|
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
assert config.compilation_config.splitting_ops == []
|
assert config.compilation_config.splitting_ops == []
|
||||||
# enable_attn_fusion is directly support under
|
# enable_attn_fusion is directly support under
|
||||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||||
# is unchanged.
|
# is unchanged.
|
||||||
assert config.compilation_config.cudagraph_mode == \
|
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||||
CUDAGraphMode.PIECEWISE
|
|
||||||
|
|||||||
@@ -4,10 +4,15 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||||
support_torch_compile)
|
from vllm.config import (
|
||||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
CacheConfig,
|
||||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
@@ -18,32 +23,42 @@ MLP_SIZE = 128
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
def run_model(
|
||||||
cudagraph_runtime_mode: CUDAGraphMode):
|
vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
|
||||||
|
):
|
||||||
with set_forward_context({}, vllm_config=vllm_config):
|
with set_forward_context({}, vllm_config=vllm_config):
|
||||||
# warmup for the model with cudagraph_mode NONE
|
# warmup for the model with cudagraph_mode NONE
|
||||||
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||||
|
|
||||||
# simulate cudagraphs capturing
|
# simulate cudagraphs capturing
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=2, )):
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(2, MLP_SIZE).cuda())
|
model(torch.randn(2, MLP_SIZE).cuda())
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=1, )):
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(1, MLP_SIZE).cuda())
|
model(torch.randn(1, MLP_SIZE).cuda())
|
||||||
|
|
||||||
# simulate cudagraphs replay
|
# simulate cudagraphs replay
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
|
{},
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=BatchDescriptor(
|
batch_descriptor=BatchDescriptor(
|
||||||
num_tokens=2, )):
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(torch.randn(2, MLP_SIZE).cuda())
|
output = model(torch.randn(2, MLP_SIZE).cuda())
|
||||||
|
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
@@ -52,22 +67,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module,
|
|||||||
|
|
||||||
def test_ignore_torch_compile_decorator():
|
def test_ignore_torch_compile_decorator():
|
||||||
# piecewise
|
# piecewise
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_cudagraph=True,
|
use_cudagraph=True,
|
||||||
splitting_ops=["silly.attention"],
|
splitting_ops=["silly.attention"],
|
||||||
cudagraph_capture_sizes=[1, 2],
|
cudagraph_capture_sizes=[1, 2],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class A(nn.Module):
|
class A(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
|
||||||
*,
|
) -> None:
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -79,15 +93,13 @@ def test_ignore_torch_compile_decorator():
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
@ignore_torch_compile
|
@ignore_torch_compile
|
||||||
class B(A):
|
class B(A): ...
|
||||||
...
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class C(B):
|
class C(B): ...
|
||||||
...
|
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# A has support_torch_compile
|
# A has support_torch_compile
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
@@ -101,7 +113,7 @@ def test_ignore_torch_compile_decorator():
|
|||||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# B's ignore_torch_compile should override A's support_torch_compile
|
# B's ignore_torch_compile should override A's support_torch_compile
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
@@ -114,7 +126,7 @@ def test_ignore_torch_compile_decorator():
|
|||||||
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# C's support_torch_compile should override B's ignore_torch_compile
|
# C's support_torch_compile should override B's ignore_torch_compile
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
@@ -130,15 +142,11 @@ def test_ignore_torch_compile_decorator():
|
|||||||
|
|
||||||
# Only enable torch.compile if
|
# Only enable torch.compile if
|
||||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
@support_torch_compile(
|
||||||
kv_sharing_fast_prefill)
|
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
|
||||||
|
)
|
||||||
class B(nn.Module):
|
class B(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -152,15 +160,11 @@ class B(nn.Module):
|
|||||||
|
|
||||||
# Only enable torch.compile if
|
# Only enable torch.compile if
|
||||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
@support_torch_compile(
|
||||||
cache_config.kv_sharing_fast_prefill)
|
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
|
||||||
|
)
|
||||||
class A(nn.Module):
|
class A(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
@@ -175,18 +179,21 @@ class A(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def test_conditional_compile_enable_if():
|
def test_conditional_compile_enable_if():
|
||||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
vllm_config = VllmConfig(
|
||||||
kv_sharing_fast_prefill=True, ),
|
cache_config=CacheConfig(
|
||||||
|
kv_sharing_fast_prefill=True,
|
||||||
|
),
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_cudagraph=True,
|
use_cudagraph=True,
|
||||||
splitting_ops=["silly.attention"],
|
splitting_ops=["silly.attention"],
|
||||||
cudagraph_capture_sizes=[1, 2],
|
cudagraph_capture_sizes=[1, 2],
|
||||||
))
|
),
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# A has support_torch_compile but enable_if fn returns False
|
# A has support_torch_compile but enable_if fn returns False
|
||||||
# enalbe_if will be True for B, so we expect mod1 and mod2
|
# enalbe_if will be True for B, so we expect mod1 and mod2
|
||||||
@@ -204,17 +211,20 @@ def test_conditional_compile_enable_if():
|
|||||||
|
|
||||||
# Set kv_sharing_fast_prefill=False
|
# Set kv_sharing_fast_prefill=False
|
||||||
# which will cause A to be compiled and B to not be compiled
|
# which will cause A to be compiled and B to not be compiled
|
||||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
vllm_config = VllmConfig(
|
||||||
kv_sharing_fast_prefill=False, ),
|
cache_config=CacheConfig(
|
||||||
|
kv_sharing_fast_prefill=False,
|
||||||
|
),
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_cudagraph=True,
|
use_cudagraph=True,
|
||||||
splitting_ops=["silly.attention"],
|
splitting_ops=["silly.attention"],
|
||||||
cudagraph_capture_sizes=[1, 2],
|
cudagraph_capture_sizes=[1, 2],
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1,
|
num_graphs_seen=1,
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ from tests.quantization.utils import is_quant_method_supported
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
|
||||||
PassConfig)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
@@ -25,43 +24,54 @@ from ..utils import create_new_process_for_each_test
|
|||||||
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
||||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||||
("facebook/opt-125m", {}),
|
("facebook/opt-125m", {}),
|
||||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
(
|
||||||
|
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||||
|
{
|
||||||
"dtype": torch.float16,
|
"dtype": torch.float16,
|
||||||
}),
|
},
|
||||||
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
|
),
|
||||||
|
(
|
||||||
|
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
|
||||||
|
{
|
||||||
"dtype": torch.float16,
|
"dtype": torch.float16,
|
||||||
}),
|
},
|
||||||
|
),
|
||||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||||
]
|
]
|
||||||
|
|
||||||
if all:
|
if all:
|
||||||
|
|
||||||
# TODO: figure out why this fails.
|
# TODO: figure out why this fails.
|
||||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gguf"
|
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
|
||||||
}))
|
)
|
||||||
|
|
||||||
if is_quant_method_supported("gptq"):
|
if is_quant_method_supported("gptq"):
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gptq"
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
|
||||||
}))
|
)
|
||||||
|
|
||||||
if is_quant_method_supported("gptq_marlin"):
|
if is_quant_method_supported("gptq_marlin"):
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gptq_marlin"
|
(
|
||||||
}))
|
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
|
||||||
|
{"quantization": "gptq_marlin"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if is_quant_method_supported("gptq_marlin_24"):
|
if is_quant_method_supported("gptq_marlin_24"):
|
||||||
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gptq_marlin_24"
|
(
|
||||||
}))
|
"alexm-nm/tinyllama-24-marlin24-4bit-g128",
|
||||||
|
{"quantization": "gptq_marlin_24"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
TEST_MODELS.append(
|
||||||
"quantization": "AWQ"
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
|
||||||
}))
|
)
|
||||||
|
|
||||||
if keywords is None:
|
if keywords is None:
|
||||||
return TEST_MODELS
|
return TEST_MODELS
|
||||||
@@ -95,22 +105,34 @@ def test_full_graph(
|
|||||||
"compilation_config, model_info",
|
"compilation_config, model_info",
|
||||||
[
|
[
|
||||||
# additional compile sizes, only some of the models
|
# additional compile sizes, only some of the models
|
||||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
(
|
||||||
compile_sizes=[1, 2]), model)
|
CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
|
||||||
|
model,
|
||||||
|
)
|
||||||
for model in models_list(all=False)
|
for model in models_list(all=False)
|
||||||
] + [
|
]
|
||||||
|
+ [
|
||||||
# RMSNorm + quant fusion, only 8-bit quant models
|
# RMSNorm + quant fusion, only 8-bit quant models
|
||||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
(
|
||||||
|
CompilationConfig(
|
||||||
|
level=CompilationLevel.PIECEWISE,
|
||||||
custom_ops=["+rms_norm"],
|
custom_ops=["+rms_norm"],
|
||||||
pass_config=PassConfig(enable_fusion=True,
|
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||||
enable_noop=True)), model)
|
),
|
||||||
|
model,
|
||||||
|
)
|
||||||
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||||
] + [
|
]
|
||||||
|
+ [
|
||||||
# Test depyf integration works
|
# Test depyf integration works
|
||||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
(
|
||||||
debug_dump_path=tempfile.gettempdir()),
|
CompilationConfig(
|
||||||
("facebook/opt-125m", {})),
|
level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
|
||||||
] + [
|
),
|
||||||
|
("facebook/opt-125m", {}),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ [
|
||||||
# graph inductor partition
|
# graph inductor partition
|
||||||
(
|
(
|
||||||
CompilationConfig(
|
CompilationConfig(
|
||||||
@@ -119,20 +141,24 @@ def test_full_graph(
|
|||||||
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
|
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
|
||||||
use_inductor_graph_partition=True,
|
use_inductor_graph_partition=True,
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
compile_sizes=[1, 2]),
|
compile_sizes=[1, 2],
|
||||||
model) for model in models_list(all=False)
|
),
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
for model in models_list(all=False)
|
||||||
if is_torch_equal_or_newer("2.9.0.dev")
|
if is_torch_equal_or_newer("2.9.0.dev")
|
||||||
])
|
],
|
||||||
|
)
|
||||||
# only test some of the models
|
# only test some of the models
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_custom_compile_config(
|
def test_custom_compile_config(
|
||||||
compilation_config: CompilationConfig,
|
compilation_config: CompilationConfig,
|
||||||
model_info: tuple[str, dict[str, Any]],
|
model_info: tuple[str, dict[str, Any]],
|
||||||
):
|
):
|
||||||
if (compilation_config.use_inductor_graph_partition
|
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||||
and not is_torch_equal_or_newer("2.9.0.dev")):
|
"2.9.0.dev"
|
||||||
pytest.skip("inductor graph partition is only available "
|
):
|
||||||
"in PyTorch 2.9+")
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
|
|
||||||
model, model_kwargs = model_info
|
model, model_kwargs = model_info
|
||||||
print(f"MODEL={model}")
|
print(f"MODEL={model}")
|
||||||
@@ -156,8 +182,7 @@ def test_fp8_kv_scale_compile(optimization_level: int):
|
|||||||
|
|
||||||
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
||||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
pytest.skip("inductor graph partition is only available "
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
"in PyTorch 2.9+")
|
|
||||||
|
|
||||||
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
@@ -171,14 +196,16 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
|||||||
"kv_cache_dtype": "fp8",
|
"kv_cache_dtype": "fp8",
|
||||||
"max_model_len": 1024,
|
"max_model_len": 1024,
|
||||||
}
|
}
|
||||||
with caplog_vllm.at_level(
|
with (
|
||||||
logging.DEBUG), global_force_attn_backend_context_manager(
|
caplog_vllm.at_level(logging.DEBUG),
|
||||||
_Backend.FLASHINFER):
|
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
|
||||||
|
):
|
||||||
run_model(compilation_config, model, model_kwargs)
|
run_model(compilation_config, model, model_kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert ("Fused quantization onto 48 attention nodes"
|
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
|
||||||
in caplog_vllm.text), caplog_vllm.text
|
caplog_vllm.text
|
||||||
|
)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
# Note: this message is only triggered when the compilation goes
|
# Note: this message is only triggered when the compilation goes
|
||||||
# through the custom pass. Due to multiple layers of cache on
|
# through the custom pass. Due to multiple layers of cache on
|
||||||
@@ -189,8 +216,11 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
|||||||
assert "Fused quantization" not in caplog_vllm.text
|
assert "Fused quantization" not in caplog_vllm.text
|
||||||
|
|
||||||
|
|
||||||
def run_model(compile_config: Union[int, CompilationConfig], model: str,
|
def run_model(
|
||||||
model_kwargs: dict[str, Any]):
|
compile_config: Union[int, CompilationConfig],
|
||||||
|
model: str,
|
||||||
|
model_kwargs: dict[str, Any],
|
||||||
|
):
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
|
|||||||
@@ -14,10 +14,8 @@ from vllm.compilation.post_cleanup import PostCleanupPass
|
|||||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
GroupShape)
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
Fp8LinearOp)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -28,7 +26,6 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
|||||||
|
|
||||||
|
|
||||||
class TestSiluMul(torch.nn.Module):
|
class TestSiluMul(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int = 128):
|
def __init__(self, hidden_size: int = 128):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
@@ -36,8 +33,7 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
|
|
||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
self.w = torch.rand(hidden_size,
|
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
hidden_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
self.fp8_linear = Fp8LinearOp(
|
self.fp8_linear = Fp8LinearOp(
|
||||||
act_quant_static=True,
|
act_quant_static=True,
|
||||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||||
@@ -46,17 +42,14 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
x2 = self.fp8_linear.apply(y,
|
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||||
self.w,
|
|
||||||
self.wscale,
|
|
||||||
input_scale=self.wscale)
|
|
||||||
return x2
|
return x2
|
||||||
else:
|
else:
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def example_inputs(self, num_tokens=32, hidden_size=128):
|
def example_inputs(self, num_tokens=32, hidden_size=128):
|
||||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
|
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
|
||||||
|
|
||||||
def ops_in_model(self, do_fusion):
|
def ops_in_model(self, do_fusion):
|
||||||
if TEST_FP8 and do_fusion:
|
if TEST_FP8 and do_fusion:
|
||||||
@@ -69,7 +62,6 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestFusedAddRMSNorm(torch.nn.Module):
|
class TestFusedAddRMSNorm(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -78,10 +70,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||||
|
|
||||||
self.gate_proj = torch.nn.Parameter(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
torch.empty((intermediate_size, hidden_size), dtype=dtype))
|
torch.empty((intermediate_size, hidden_size), dtype=dtype)
|
||||||
|
)
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||||
self.norm.weight = torch.nn.Parameter(
|
self.norm.weight = torch.nn.Parameter(
|
||||||
torch.ones(intermediate_size, dtype=dtype))
|
torch.ones(intermediate_size, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
|
|
||||||
@@ -89,8 +83,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||||
|
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.w = torch.rand(hidden_size,
|
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||||
intermediate_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
@@ -120,10 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
|
|
||||||
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
||||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
dtype=dtype)
|
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
|
||||||
dtype=dtype)
|
|
||||||
return (hidden_states, residual)
|
return (hidden_states, residual)
|
||||||
|
|
||||||
def ops_in_model(self, do_fusion):
|
def ops_in_model(self, do_fusion):
|
||||||
@@ -137,12 +128,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestRotaryEmbedding(torch.nn.Module):
|
class TestRotaryEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||||
def __init__(self,
|
|
||||||
head_dim=64,
|
|
||||||
rotary_dim=None,
|
|
||||||
max_position=2048,
|
|
||||||
base=10000):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.rotary_dim = rotary_dim or head_dim
|
self.rotary_dim = rotary_dim or head_dim
|
||||||
@@ -173,21 +159,15 @@ class TestRotaryEmbedding(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||||
|
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
|
||||||
def __init__(self,
|
|
||||||
head_dim=64,
|
|
||||||
num_heads=4,
|
|
||||||
max_position=2048,
|
|
||||||
base=10000):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.hidden_size = head_dim * num_heads
|
self.hidden_size = head_dim * num_heads
|
||||||
|
|
||||||
self.qkv_proj = torch.nn.Linear(self.hidden_size,
|
self.qkv_proj = torch.nn.Linear(
|
||||||
self.hidden_size * 3,
|
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
|
||||||
bias=False,
|
)
|
||||||
dtype=torch.float16)
|
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -233,21 +213,24 @@ MODELS = [
|
|||||||
|
|
||||||
@pytest.mark.parametrize("model_class", MODELS)
|
@pytest.mark.parametrize("model_class", MODELS)
|
||||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
|
||||||
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(
|
vllm_config.compilation_config = CompilationConfig(
|
||||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
|
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
|
||||||
|
)
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||||
|
|
||||||
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
passes = (
|
||||||
if do_fusion else [noop_pass, cleanup_pass])
|
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||||
|
if do_fusion
|
||||||
|
else [noop_pass, cleanup_pass]
|
||||||
|
)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
|
|
||||||
backend_func = TestBackend(*passes, func_pass)
|
backend_func = TestBackend(*passes, func_pass)
|
||||||
@@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
|||||||
# check if the functionalization pass is applied
|
# check if the functionalization pass is applied
|
||||||
for op in model.ops_in_model(do_fusion):
|
for op in model.ops_in_model(do_fusion):
|
||||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||||
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
|
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
|
||||||
is None) # noqa: E501
|
|
||||||
|
|
||||||
# make sure the ops were all de-functionalized
|
# make sure the ops were all de-functionalized
|
||||||
found = dict()
|
found = dict()
|
||||||
|
|||||||
@@ -5,17 +5,26 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.plugins
|
import vllm.plugins
|
||||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
from vllm.compilation.fusion import (
|
||||||
RMSNormQuantFusionPass)
|
FUSED_OPS,
|
||||||
|
QUANT_OPS,
|
||||||
|
FusedRMSQuantKey,
|
||||||
|
RMSNormQuantFusionPass,
|
||||||
|
)
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
|
||||||
VllmConfig)
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape, QuantKey, ScaleDesc)
|
GroupShape,
|
||||||
|
QuantKey,
|
||||||
|
ScaleDesc,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
|
Fp8LinearOp,
|
||||||
|
cutlass_fp8_supported,
|
||||||
|
maybe_create_device_identity,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import override_cutlass_fp8_supported
|
from ..utils import override_cutlass_fp8_supported
|
||||||
@@ -25,9 +34,15 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
|||||||
|
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self, hidden_size: int, eps: float, static: bool,
|
self,
|
||||||
cuda_force_torch: bool, *args, **kwargs):
|
hidden_size: int,
|
||||||
|
eps: float,
|
||||||
|
static: bool,
|
||||||
|
cuda_force_torch: bool,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.cuda_force_torch = cuda_force_torch
|
self.cuda_force_torch = cuda_force_torch
|
||||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||||
@@ -54,17 +69,15 @@ class TestModel(torch.nn.Module):
|
|||||||
resid = torch.sqrt(x)
|
resid = torch.sqrt(x)
|
||||||
y = self.norm[0](x)
|
y = self.norm[0](x)
|
||||||
|
|
||||||
x2 = self.fp8_linear.apply(y,
|
x2 = self.fp8_linear.apply(
|
||||||
self.w[0],
|
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||||
self.wscale[0],
|
)
|
||||||
input_scale=self.scale[0])
|
|
||||||
# make sure resid is used for replacement to work
|
# make sure resid is used for replacement to work
|
||||||
y2, resid = self.norm[1](x2, resid)
|
y2, resid = self.norm[1](x2, resid)
|
||||||
|
|
||||||
x3 = self.fp8_linear.apply(y2,
|
x3 = self.fp8_linear.apply(
|
||||||
self.w[1],
|
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||||
self.wscale[1],
|
)
|
||||||
input_scale=self.scale[1])
|
|
||||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||||
return y3
|
return y3
|
||||||
|
|
||||||
@@ -74,7 +87,7 @@ class TestModel(torch.nn.Module):
|
|||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
return [
|
return [
|
||||||
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
|
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
|
||||||
FUSED_OPS[FusedRMSQuantKey(self.key, True)]
|
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -85,22 +98,27 @@ class TestModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("static", [True, False])
|
@pytest.mark.parametrize("static", [True, False])
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize("cuda_force_torch",
|
@pytest.mark.parametrize(
|
||||||
[True, False] if cutlass_fp8_supported() else [True])
|
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
)
|
||||||
reason="Only test on CUDA and ROCm")
|
@pytest.mark.skipif(
|
||||||
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||||
cuda_force_torch):
|
)
|
||||||
|
def test_fusion_rmsnorm_quant(
|
||||||
|
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
|
||||||
|
):
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
|||||||
@@ -10,14 +10,24 @@ from vllm.compilation.collective_fusion import AllReduceFusionPass
|
|||||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
from vllm.config import (
|
||||||
ModelConfig, PassConfig, VllmConfig)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
DeviceConfig,
|
||||||
|
ModelConfig,
|
||||||
|
PassConfig,
|
||||||
|
VllmConfig,
|
||||||
|
)
|
||||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
from vllm.distributed.parallel_state import (
|
||||||
initialize_model_parallel)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
GroupShape, QuantFP8)
|
GroupShape,
|
||||||
|
QuantFP8,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
@@ -26,7 +36,6 @@ from .backend import TestBackend
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -47,7 +56,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -68,25 +76,22 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.norm = RMSNorm(hidden_size, eps)
|
self.norm = RMSNorm(hidden_size, eps)
|
||||||
self.quant_fp8 = QuantFP8(static=True,
|
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
|
||||||
group_shape=GroupShape.PER_TENSOR)
|
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.output = torch.empty((token_num, hidden_size),
|
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
torch.ops._C.static_scaled_fp8_quant(self.output,
|
torch.ops._C.static_scaled_fp8_quant(
|
||||||
norm_output.contiguous(),
|
self.output, norm_output.contiguous(), self.scale
|
||||||
self.scale)
|
)
|
||||||
return self.output, residual_output
|
return self.output, residual_output
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
@@ -95,35 +100,33 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.all_reduce.default,
|
torch.ops.vllm.all_reduce.default,
|
||||||
torch.ops._C.static_scaled_fp8_quant.default
|
torch.ops._C.static_scaled_fp8_quant.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.norm = RMSNorm(hidden_size, eps)
|
self.norm = RMSNorm(hidden_size, eps)
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.output = torch.empty((token_num, hidden_size),
|
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
round_up = lambda x, y: (x + y - 1) // y * y
|
round_up = lambda x, y: (x + y - 1) // y * y
|
||||||
rounded_m = round_up(token_num, 128)
|
rounded_m = round_up(token_num, 128)
|
||||||
scale_n = hidden_size // 16
|
scale_n = hidden_size // 16
|
||||||
rounded_n = round_up(scale_n, 4)
|
rounded_n = round_up(scale_n, 4)
|
||||||
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
|
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
|
||||||
dtype=torch.int32)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
||||||
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
|
torch.ops._C.scaled_fp4_quant(
|
||||||
self.output_scale, self.scale)
|
self.output, norm_output, self.output_scale, self.scale
|
||||||
|
)
|
||||||
return self.output, residual_output, self.output_scale
|
return self.output, residual_output, self.output_scale
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
@@ -132,7 +135,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
|||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.all_reduce.default,
|
torch.ops.vllm.all_reduce.default,
|
||||||
torch.ops._C.scaled_fp4_quant.default
|
torch.ops._C.scaled_fp4_quant.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -145,41 +148,55 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
|||||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||||
# TODO: Enable with torch==2.8.0
|
# TODO: Enable with torch==2.8.0
|
||||||
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||||
])
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("seq_len", [8])
|
@pytest.mark.parametrize("seq_len", [8])
|
||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not find_spec("flashinfer")
|
not find_spec("flashinfer")
|
||||||
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
|
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
|
||||||
reason="flashinfer is not found or flashinfer "
|
reason="flashinfer is not found or flashinfer "
|
||||||
"is not compiled with trtllm_allreduce_fusion")
|
"is not compiled with trtllm_allreduce_fusion",
|
||||||
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
)
|
||||||
batch_size: int, seq_len: int,
|
def test_all_reduce_fusion_pass_replace(
|
||||||
hidden_size: int, dtype: torch.dtype):
|
test_model: torch.nn.Module,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
num_processes = 2
|
num_processes = 2
|
||||||
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
if (
|
||||||
and not current_platform.has_device_capability(100)):
|
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
||||||
pytest.skip("Skip as nvfp4 is only supported on "
|
and not current_platform.has_device_capability(100)
|
||||||
"devices with compute capability 10.0 (Blackwell)")
|
):
|
||||||
|
pytest.skip(
|
||||||
|
"Skip as nvfp4 is only supported on "
|
||||||
|
"devices with compute capability 10.0 (Blackwell)"
|
||||||
|
)
|
||||||
|
|
||||||
def run_torch_spawn(fn, nprocs):
|
def run_torch_spawn(fn, nprocs):
|
||||||
torch.multiprocessing.spawn(fn,
|
torch.multiprocessing.spawn(
|
||||||
args=(num_processes, test_model,
|
fn,
|
||||||
batch_size, seq_len, hidden_size,
|
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||||
dtype),
|
nprocs=nprocs,
|
||||||
nprocs=nprocs)
|
)
|
||||||
|
|
||||||
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
def all_reduce_fusion_pass_on_test_model(
|
||||||
|
local_rank: int,
|
||||||
|
world_size: int,
|
||||||
test_model_cls: torch.nn.Module,
|
test_model_cls: torch.nn.Module,
|
||||||
batch_size: int, seq_len: int,
|
batch_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype):
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
@@ -187,39 +204,42 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
custom_ops=["+rms_norm", "+quant_fp8"]))
|
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
|
||||||
|
)
|
||||||
|
)
|
||||||
vllm_config.compilation_config.pass_config = PassConfig(
|
vllm_config.compilation_config.pass_config = PassConfig(
|
||||||
enable_fi_allreduce_fusion=True, enable_noop=True)
|
enable_fi_allreduce_fusion=True, enable_noop=True
|
||||||
|
)
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
# in the vllm_config, it's not really used.
|
# in the vllm_config, it's not really used.
|
||||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||||
vllm_config.model_config = ModelConfig(model=model_name,
|
vllm_config.model_config = ModelConfig(
|
||||||
trust_remote_code=True,
|
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||||
dtype=dtype,
|
)
|
||||||
seed=42)
|
|
||||||
|
|
||||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
|
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
|
||||||
cleanup_pass)
|
|
||||||
|
|
||||||
token_num = batch_size * seq_len
|
token_num = batch_size * seq_len
|
||||||
model = test_model_cls(hidden_size, token_num)
|
model = test_model_cls(hidden_size, token_num)
|
||||||
|
|||||||
@@ -19,14 +19,23 @@ from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
|||||||
from vllm.compilation.fx_utils import find_op_nodes
|
from vllm.compilation.fx_utils import find_op_nodes
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
from vllm.config import (
|
||||||
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
|
CacheConfig,
|
||||||
set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
ModelConfig,
|
||||||
|
PassConfig,
|
||||||
|
SchedulerConfig,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
|
QuantKey,
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
kFp8StaticTensorSym,
|
||||||
Fp8LinearOp)
|
kNvfp4Quant,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
@@ -40,14 +49,16 @@ backend_unfused: Optional[TestBackend] = None
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, quant_key",
|
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
|
||||||
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
|
)
|
||||||
@pytest.mark.parametrize("use_triton_fa", [True, False])
|
@pytest.mark.parametrize("use_triton_fa", [True, False])
|
||||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
@pytest.mark.skipif(
|
||||||
reason="V0 attn quant fusion only on ROCm")
|
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
|
||||||
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
)
|
||||||
quant_key: QuantKey, use_triton_fa: bool):
|
def test_attention_fusion_v0(
|
||||||
|
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
|
||||||
|
):
|
||||||
# Clean Dynamo cache to avoid reusing other test cases
|
# Clean Dynamo cache to avoid reusing other test cases
|
||||||
# (for some reason the reset at the end is not enough)
|
# (for some reason the reset at the end is not enough)
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
@@ -69,22 +80,24 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
)
|
)
|
||||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=compile_config,
|
||||||
model_config=ModelConfig(
|
model_config=ModelConfig(
|
||||||
model=model,
|
model=model,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||||
|
|
||||||
llm = LLM(model,
|
llm = LLM(
|
||||||
|
model,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
compilation_config=compile_config,
|
compilation_config=compile_config,
|
||||||
gpu_memory_utilization=0.5,
|
gpu_memory_utilization=0.5,
|
||||||
max_model_len=2048)
|
max_model_len=2048,
|
||||||
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.0,
|
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
|
||||||
max_tokens=10,
|
|
||||||
top_p=0.95)
|
|
||||||
|
|
||||||
unfused_output = llm.generate(prompts, sampling_params)
|
unfused_output = llm.generate(prompts, sampling_params)
|
||||||
backend_unfused = None # Reset backend to make sure llm gets released
|
backend_unfused = None # Reset backend to make sure llm gets released
|
||||||
@@ -97,21 +110,25 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
backend="tests.compile.test_fusion_attn.backend",
|
backend="tests.compile.test_fusion_attn.backend",
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
)
|
)
|
||||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=compile_config,
|
||||||
model_config=ModelConfig(
|
model_config=ModelConfig(
|
||||||
model=model,
|
model=model,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# AttnFusionPass needs attention layers to be registered in config upon init
|
# AttnFusionPass needs attention layers to be registered in config upon init
|
||||||
# so we initialize it during compilation.
|
# so we initialize it during compilation.
|
||||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||||
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
||||||
llm2 = LLM(model,
|
llm2 = LLM(
|
||||||
|
model,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
compilation_config=compile_config,
|
compilation_config=compile_config,
|
||||||
gpu_memory_utilization=0.5,
|
gpu_memory_utilization=0.5,
|
||||||
max_model_len=2048)
|
max_model_len=2048,
|
||||||
|
)
|
||||||
|
|
||||||
# check support
|
# check support
|
||||||
attn_fusion_supported = [
|
attn_fusion_supported = [
|
||||||
@@ -132,9 +149,9 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
for i in range(len(attn_nodes_pre)):
|
for i in range(len(attn_nodes_pre)):
|
||||||
assert attn_nodes_pre[i].kwargs["output_scale"] is None
|
assert attn_nodes_pre[i].kwargs["output_scale"] is None
|
||||||
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
|
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
|
||||||
assert fused == attn_fusion_supported[i], \
|
assert fused == attn_fusion_supported[i], (
|
||||||
f"Node {i} {'' if fused else 'not '} expected " \
|
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
|
||||||
f"to have fused output quant"
|
)
|
||||||
|
|
||||||
# check outputs
|
# check outputs
|
||||||
fused_output = llm2.generate(prompts, sampling_params)
|
fused_output = llm2.generate(prompts, sampling_params)
|
||||||
@@ -160,9 +177,16 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
class AttentionQuantPatternModel(torch.nn.Module):
|
class AttentionQuantPatternModel(torch.nn.Module):
|
||||||
"""Base model for AttentionQuantPattern fusion."""
|
"""Base model for AttentionQuantPattern fusion."""
|
||||||
|
|
||||||
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
|
def __init__(
|
||||||
kv_cache_dtype: torch.dtype, device: torch.device,
|
self,
|
||||||
vllm_config: VllmConfig, **kwargs):
|
num_qo_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
kv_cache_dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_qo_heads = num_qo_heads
|
self.num_qo_heads = num_qo_heads
|
||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
@@ -197,33 +221,30 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
|
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
|
||||||
-> AttentionMetadata:
|
|
||||||
"""Initialize attention metadata."""
|
"""Initialize attention metadata."""
|
||||||
|
|
||||||
# Create common attn metadata
|
# Create common attn metadata
|
||||||
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
|
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
|
||||||
query_lens=[1] * batch_size)
|
|
||||||
common_attn_metadata = create_common_attn_metadata(
|
common_attn_metadata = create_common_attn_metadata(
|
||||||
batch_spec,
|
batch_spec, self.block_size, self.device, arange_block_indices=True
|
||||||
self.block_size,
|
)
|
||||||
self.device,
|
|
||||||
arange_block_indices=True)
|
|
||||||
|
|
||||||
max_blocks = (max(batch_spec.seq_lens) + self.block_size -
|
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
|
||||||
1) // self.block_size
|
|
||||||
num_blocks = batch_size * max_blocks
|
num_blocks = batch_size * max_blocks
|
||||||
|
|
||||||
# Create dummy KV cache for FlashInfer TRTLLM
|
# Create dummy KV cache for FlashInfer TRTLLM
|
||||||
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||||
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
|
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||||
kv_cache = torch.zeros(num_blocks,
|
kv_cache = torch.zeros(
|
||||||
|
num_blocks,
|
||||||
2,
|
2,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.block_size,
|
self.block_size,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
device=self.device)
|
device=self.device,
|
||||||
|
)
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
# k/v as 1st dimension
|
# k/v as 1st dimension
|
||||||
if use_hnd:
|
if use_hnd:
|
||||||
@@ -239,7 +260,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
|||||||
|
|
||||||
# Build attn metadata
|
# Build attn metadata
|
||||||
self.attn_metadata = self.builder.build(
|
self.attn_metadata = self.builder.build(
|
||||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata)
|
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||||
|
)
|
||||||
|
|
||||||
return self.attn_metadata
|
return self.attn_metadata
|
||||||
|
|
||||||
@@ -254,27 +276,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
|||||||
|
|
||||||
self.fp8_linear = Fp8LinearOp(
|
self.fp8_linear = Fp8LinearOp(
|
||||||
act_quant_static=self.quant_key.scale.static,
|
act_quant_static=self.quant_key.scale.static,
|
||||||
act_quant_group_shape=self.quant_key.scale.group_shape)
|
act_quant_group_shape=self.quant_key.scale.group_shape,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_size = self.num_qo_heads * self.head_size
|
hidden_size = self.num_qo_heads * self.head_size
|
||||||
self.w = kwargs.get(
|
self.w = kwargs.get(
|
||||||
"w", {
|
"w",
|
||||||
"weight":
|
{
|
||||||
torch.randn(hidden_size, hidden_size).to(
|
"weight": torch.randn(hidden_size, hidden_size)
|
||||||
dtype=FP8_DTYPE, device=self.device).t(),
|
.to(dtype=FP8_DTYPE, device=self.device)
|
||||||
"wscale":
|
.t(),
|
||||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
"wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||||
"scale":
|
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
},
|
||||||
})
|
)
|
||||||
|
|
||||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
"""Forward pass that creates the pattern to be fused."""
|
"""Forward pass that creates the pattern to be fused."""
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
return self.fp8_linear.apply(input=attn_output,
|
return self.fp8_linear.apply(
|
||||||
|
input=attn_output,
|
||||||
weight=self.w["weight"],
|
weight=self.w["weight"],
|
||||||
weight_scale=self.w["wscale"],
|
weight_scale=self.w["wscale"],
|
||||||
input_scale=self.w["scale"])
|
input_scale=self.w["scale"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||||
@@ -287,42 +312,54 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
|||||||
|
|
||||||
hidden_size = self.num_qo_heads * self.head_size
|
hidden_size = self.num_qo_heads * self.head_size
|
||||||
self.w = kwargs.get(
|
self.w = kwargs.get(
|
||||||
"w", {
|
"w",
|
||||||
"weight":
|
{
|
||||||
torch.randint(256, (hidden_size, hidden_size // 2),
|
"weight": torch.randint(
|
||||||
|
256,
|
||||||
|
(hidden_size, hidden_size // 2),
|
||||||
dtype=FP4_DTYPE,
|
dtype=FP4_DTYPE,
|
||||||
device=self.device),
|
device=self.device,
|
||||||
"wscale_swizzled":
|
),
|
||||||
torch.randn(hidden_size, hidden_size // 16).to(
|
"wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
|
||||||
dtype=FP8_DTYPE, device=self.device),
|
dtype=FP8_DTYPE, device=self.device
|
||||||
"wscale":
|
),
|
||||||
torch.tensor([500], dtype=torch.float32, device=self.device),
|
"wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
|
||||||
"scale":
|
"scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
||||||
torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
},
|
||||||
})
|
)
|
||||||
|
|
||||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
"""Forward pass that creates the pattern to be fused."""
|
"""Forward pass that creates the pattern to be fused."""
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
quant_output, output_block_scale = scaled_fp4_quant(
|
quant_output, output_block_scale = scaled_fp4_quant(
|
||||||
attn_output, 1 / self.w["scale"])
|
attn_output, 1 / self.w["scale"]
|
||||||
return cutlass_scaled_fp4_mm(a=quant_output,
|
)
|
||||||
|
return cutlass_scaled_fp4_mm(
|
||||||
|
a=quant_output,
|
||||||
b=self.w["weight"],
|
b=self.w["weight"],
|
||||||
block_scale_a=output_block_scale,
|
block_scale_a=output_block_scale,
|
||||||
block_scale_b=self.w["wscale_swizzled"],
|
block_scale_b=self.w["wscale_swizzled"],
|
||||||
alpha=self.w["scale"] * self.w["wscale"],
|
alpha=self.w["scale"] * self.w["wscale"],
|
||||||
out_dtype=attn_output.dtype)
|
out_dtype=attn_output.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
MODELS = [
|
||||||
TestAttentionFp8StaticQuantPatternModel),
|
(
|
||||||
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||||
TestAttentionNvfp4QuantPatternModel)]
|
TestAttentionFp8StaticQuantPatternModel,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||||
|
TestAttentionNvfp4QuantPatternModel,
|
||||||
|
),
|
||||||
|
]
|
||||||
HEADS = [(64, 8), (40, 8)]
|
HEADS = [(64, 8), (40, 8)]
|
||||||
elif current_platform.is_rocm():
|
elif current_platform.is_rocm():
|
||||||
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
|
MODELS = [
|
||||||
TestAttentionFp8StaticQuantPatternModel)]
|
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||||
|
]
|
||||||
HEADS = [(32, 8), (40, 8)]
|
HEADS = [(32, 8), (40, 8)]
|
||||||
else:
|
else:
|
||||||
MODELS = []
|
MODELS = []
|
||||||
@@ -331,41 +368,53 @@ else:
|
|||||||
|
|
||||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
||||||
@pytest.mark.parametrize("head_size", [128])
|
@pytest.mark.parametrize("head_size", [128])
|
||||||
@pytest.mark.parametrize("batch_size",
|
@pytest.mark.parametrize(
|
||||||
[7, 256, 533] if current_platform.is_cuda() else [8])
|
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
||||||
@pytest.mark.parametrize("backend",
|
|
||||||
[_Backend.FLASHINFER] if current_platform.is_cuda()
|
|
||||||
else [_Backend.TRITON_ATTN])
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"split_attention",
|
"backend",
|
||||||
[False, True] if current_platform.is_rocm() else [False])
|
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"split_attention", [False, True] if current_platform.is_rocm() else [False]
|
||||||
|
)
|
||||||
# TODO(boyuan): test inductor graph partition on rocm
|
# TODO(boyuan): test inductor graph partition on rocm
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"use_inductor_graph_partition",
|
"use_inductor_graph_partition",
|
||||||
[False] if current_platform.is_rocm() else [False, True])
|
[False] if current_platform.is_rocm() else [False, True],
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
)
|
||||||
reason="Only test ROCm or CUDA")
|
@pytest.mark.skipif(
|
||||||
|
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||||
|
)
|
||||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||||
@pytest.mark.skipif(current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
and not current_platform.is_device_capability((10, 0)),
|
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
|
||||||
reason="On CUDA only test on SM100(Blackwell)")
|
reason="On CUDA only test on SM100(Blackwell)",
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
)
|
||||||
reason="Only test ROCm or CUDA")
|
@pytest.mark.skipif(
|
||||||
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||||
head_size: int, batch_size: int,
|
)
|
||||||
dtype: torch.dtype, model_name: str,
|
def test_attention_quant_pattern(
|
||||||
|
num_qo_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
batch_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
model_name: str,
|
||||||
model_class: type[AttentionQuantPatternModel],
|
model_class: type[AttentionQuantPatternModel],
|
||||||
backend: _Backend, split_attention: bool,
|
backend: _Backend,
|
||||||
|
split_attention: bool,
|
||||||
use_inductor_graph_partition: bool,
|
use_inductor_graph_partition: bool,
|
||||||
monkeypatch, dist_init, caplog_vllm):
|
monkeypatch,
|
||||||
|
dist_init,
|
||||||
|
caplog_vllm,
|
||||||
|
):
|
||||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||||
|
|
||||||
if use_inductor_graph_partition and not is_torch_equal_or_newer(
|
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
"2.9.0.dev"):
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
pytest.skip("inductor graph partition is only available "
|
|
||||||
"in PyTorch 2.9+")
|
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
if split_attention:
|
if split_attention:
|
||||||
@@ -386,21 +435,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||||
),
|
),
|
||||||
cache_config=CacheConfig(cache_dtype="fp8"))
|
cache_config=CacheConfig(cache_dtype="fp8"),
|
||||||
|
)
|
||||||
|
|
||||||
# Create test inputs
|
# Create test inputs
|
||||||
q = torch.randn(batch_size,
|
q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
|
||||||
num_qo_heads * head_size,
|
k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||||
dtype=dtype,
|
v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||||
device=device)
|
|
||||||
k = torch.randn(batch_size,
|
|
||||||
num_kv_heads * head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
v = torch.randn(batch_size,
|
|
||||||
num_kv_heads * head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
# Mark first dimension as dynamic for realistic testing
|
# Mark first dimension as dynamic for realistic testing
|
||||||
torch._dynamo.mark_dynamic(q, 0)
|
torch._dynamo.mark_dynamic(q, 0)
|
||||||
@@ -409,42 +450,53 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
|
|
||||||
# Run model directly without compilation and fusion
|
# Run model directly without compilation and fusion
|
||||||
vllm_config_unfused = copy.deepcopy(vllm_config)
|
vllm_config_unfused = copy.deepcopy(vllm_config)
|
||||||
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
|
with (
|
||||||
attn_metadata=None, vllm_config=vllm_config_unfused
|
set_current_vllm_config(vllm_config_unfused),
|
||||||
), global_force_attn_backend_context_manager(backend):
|
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
|
||||||
model_unfused = model_class(num_qo_heads=num_qo_heads,
|
global_force_attn_backend_context_manager(backend),
|
||||||
|
):
|
||||||
|
model_unfused = model_class(
|
||||||
|
num_qo_heads=num_qo_heads,
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_size=head_size,
|
head_size=head_size,
|
||||||
kv_cache_dtype=FP8_DTYPE,
|
kv_cache_dtype=FP8_DTYPE,
|
||||||
device=device,
|
device=device,
|
||||||
vllm_config=vllm_config_unfused)
|
vllm_config=vllm_config_unfused,
|
||||||
|
)
|
||||||
model_unfused = model_unfused.to(device)
|
model_unfused = model_unfused.to(device)
|
||||||
|
|
||||||
forward_ctx = get_forward_context()
|
forward_ctx = get_forward_context()
|
||||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
|
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
|
||||||
batch_size, use_hnd=split_attention)
|
batch_size, use_hnd=split_attention
|
||||||
|
)
|
||||||
|
|
||||||
# Run model directly without compilation and fusion
|
# Run model directly without compilation and fusion
|
||||||
result_unfused = model_unfused(q, k, v)
|
result_unfused = model_unfused(q, k, v)
|
||||||
|
|
||||||
# Run model with attn fusion enabled
|
# Run model with attn fusion enabled
|
||||||
vllm_config.compilation_config.pass_config = PassConfig(
|
vllm_config.compilation_config.pass_config = PassConfig(
|
||||||
enable_attn_fusion=True, enable_noop=True)
|
enable_attn_fusion=True, enable_noop=True
|
||||||
with set_current_vllm_config(vllm_config), set_forward_context(
|
)
|
||||||
attn_metadata=None, vllm_config=vllm_config
|
with (
|
||||||
), global_force_attn_backend_context_manager(backend):
|
set_current_vllm_config(vllm_config),
|
||||||
model_fused = model_class(num_qo_heads=num_qo_heads,
|
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
|
||||||
|
global_force_attn_backend_context_manager(backend),
|
||||||
|
):
|
||||||
|
model_fused = model_class(
|
||||||
|
num_qo_heads=num_qo_heads,
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_size=head_size,
|
head_size=head_size,
|
||||||
kv_cache_dtype=FP8_DTYPE,
|
kv_cache_dtype=FP8_DTYPE,
|
||||||
device=device,
|
device=device,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
w=model_unfused.w)
|
w=model_unfused.w,
|
||||||
|
)
|
||||||
model_fused = model_fused.to(device)
|
model_fused = model_fused.to(device)
|
||||||
|
|
||||||
forward_ctx = get_forward_context()
|
forward_ctx = get_forward_context()
|
||||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
|
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
|
||||||
batch_size, use_hnd=split_attention)
|
batch_size, use_hnd=split_attention
|
||||||
|
)
|
||||||
|
|
||||||
# Create test backend with fusion passes enabled
|
# Create test backend with fusion passes enabled
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
@@ -454,9 +506,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
||||||
|
|
||||||
# Compile model with fusion enabled
|
# Compile model with fusion enabled
|
||||||
model_compiled = torch.compile(model_fused,
|
model_compiled = torch.compile(
|
||||||
backend=test_backend,
|
model_fused, backend=test_backend, fullgraph=True
|
||||||
fullgraph=True)
|
)
|
||||||
assert model_compiled.attn._o_scale_float is None
|
assert model_compiled.attn._o_scale_float is None
|
||||||
|
|
||||||
result_fused_1 = model_compiled(q, k, v)
|
result_fused_1 = model_compiled(q, k, v)
|
||||||
@@ -471,49 +523,49 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
|
|
||||||
assert model_compiled.attn._o_scale_float is not None
|
assert model_compiled.attn._o_scale_float is not None
|
||||||
|
|
||||||
torch.testing.assert_close(result_unfused,
|
torch.testing.assert_close(
|
||||||
result_fused_2,
|
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
|
||||||
atol=1e-2,
|
)
|
||||||
rtol=1e-2)
|
|
||||||
|
|
||||||
# Check attn fusion support
|
# Check attn fusion support
|
||||||
quant_key = model_class.quant_key
|
quant_key = model_class.quant_key
|
||||||
attn_fusion_supported = [
|
attn_fusion_supported = [
|
||||||
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
|
layer.impl.fused_output_quant_supported(quant_key)
|
||||||
vllm_config.compilation_config.static_forward_context.items()
|
for key, layer in vllm_config.compilation_config.static_forward_context.items()
|
||||||
]
|
]
|
||||||
if any(attn_fusion_supported):
|
if any(attn_fusion_supported):
|
||||||
# Check quantization ops in the graph before and after fusion
|
# Check quantization ops in the graph before and after fusion
|
||||||
test_backend.check_before_ops([QUANT_OPS[quant_key]],
|
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
|
||||||
fully_replaced=True)
|
|
||||||
|
|
||||||
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
||||||
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
||||||
|
|
||||||
# Check attention ops in the graph before and after fusion
|
# Check attention ops in the graph before and after fusion
|
||||||
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
|
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
|
||||||
attn_nodes_post = list(find_op_nodes(ATTN_OP,
|
attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))
|
||||||
test_backend.graph_post_pass))
|
|
||||||
|
|
||||||
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
|
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
|
||||||
assert len(attn_nodes_pre) == len(attn_nodes_post), \
|
assert len(attn_nodes_pre) == len(attn_nodes_post), (
|
||||||
"Should have same number of attention nodes before and after fusion"
|
"Should have same number of attention nodes before and after fusion"
|
||||||
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
|
)
|
||||||
|
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
|
||||||
"Attention should not have output_scale before fusion"
|
"Attention should not have output_scale before fusion"
|
||||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
|
)
|
||||||
|
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
|
||||||
"Attention should have output_scale after fusion"
|
"Attention should have output_scale after fusion"
|
||||||
|
)
|
||||||
|
|
||||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
|
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale before fusion"
|
"Attention should not have output_block_scale before fusion"
|
||||||
|
)
|
||||||
if quant_key.dtype == FP8_DTYPE:
|
if quant_key.dtype == FP8_DTYPE:
|
||||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale after FP8 fusion"
|
"Attention should not have output_block_scale after FP8 fusion"
|
||||||
|
)
|
||||||
elif quant_key.dtype == FP4_DTYPE:
|
elif quant_key.dtype == FP4_DTYPE:
|
||||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
|
||||||
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
|
"Attention should have output_block_scale after FP4 fusion"
|
||||||
|
) # noqa: E501
|
||||||
|
|
||||||
# Check that results are close
|
# Check that results are close
|
||||||
torch.testing.assert_close(result_unfused,
|
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)
|
||||||
result_fused_1,
|
|
||||||
atol=1e-2,
|
|
||||||
rtol=1e-2)
|
|
||||||
|
|||||||
@@ -6,14 +6,12 @@ import torch
|
|||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
|
||||||
VllmConfig)
|
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype",
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||||
[torch.float16, torch.bfloat16, torch.float32])
|
|
||||||
@pytest.mark.parametrize("num_tokens", [256, 1024])
|
@pytest.mark.parametrize("num_tokens", [256, 1024])
|
||||||
@pytest.mark.parametrize("hidden_size", [64, 4096])
|
@pytest.mark.parametrize("hidden_size", [64, 4096])
|
||||||
def test_noop_elimination(dtype, num_tokens, hidden_size):
|
def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||||
@@ -22,7 +20,6 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
|||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Chain of reshapes
|
# Chain of reshapes
|
||||||
y = x.reshape(-1, 128, 32)
|
y = x.reshape(-1, 128, 32)
|
||||||
@@ -32,7 +29,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
|||||||
# Final reshape that should remain
|
# Final reshape that should remain
|
||||||
b = a.reshape(-1, 128, 32)
|
b = a.reshape(-1, 128, 32)
|
||||||
# No-op slice
|
# No-op slice
|
||||||
c = b[0:b.shape[0]]
|
c = b[0 : b.shape[0]]
|
||||||
# The pass should replace the result of this op with `c`.
|
# The pass should replace the result of this op with `c`.
|
||||||
d = torch.slice_scatter(
|
d = torch.slice_scatter(
|
||||||
torch.ones_like(c), # Dummy tensor to be scattered into
|
torch.ones_like(c), # Dummy tensor to be scattered into
|
||||||
@@ -43,10 +40,12 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
|||||||
)
|
)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
pass_config=PassConfig(enable_noop=True),
|
pass_config=PassConfig(enable_noop=True),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
|
||||||
@@ -82,17 +81,18 @@ def test_non_noop_slice_preserved():
|
|||||||
x = torch.randn(16, 16)
|
x = torch.randn(16, 16)
|
||||||
|
|
||||||
class SliceModel(torch.nn.Module):
|
class SliceModel(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
base = x.clone()
|
base = x.clone()
|
||||||
src = torch.ones(15, 16)
|
src = torch.ones(15, 16)
|
||||||
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
||||||
return x[0:-1, :], y
|
return x[0:-1, :], y
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
pass_config=PassConfig(enable_noop=True),
|
pass_config=PassConfig(enable_noop=True),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
backend = TestBackend(noop_pass)
|
backend = TestBackend(noop_pass)
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ def test_bad_callable():
|
|||||||
|
|
||||||
# Pass that inherits from InductorPass
|
# Pass that inherits from InductorPass
|
||||||
class ProperPass(InductorPass):
|
class ProperPass(InductorPass):
|
||||||
|
|
||||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -39,8 +38,7 @@ class ProperPass(InductorPass):
|
|||||||
ProperPass(),
|
ProperPass(),
|
||||||
# Can also wrap callables in CallableInductorPass for compliance
|
# Can also wrap callables in CallableInductorPass for compliance
|
||||||
CallableInductorPass(simple_callable),
|
CallableInductorPass(simple_callable),
|
||||||
CallableInductorPass(simple_callable,
|
CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
|
||||||
InductorPass.hash_source(__file__))
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_pass_manager_uuid(callable):
|
def test_pass_manager_uuid(callable):
|
||||||
@@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable):
|
|||||||
|
|
||||||
# UUID should be different due to config change
|
# UUID should be different due to config change
|
||||||
config2 = copy.deepcopy(config)
|
config2 = copy.deepcopy(config)
|
||||||
config2.compilation_config.pass_config.enable_fusion = not \
|
config2.compilation_config.pass_config.enable_fusion = (
|
||||||
config2.compilation_config.pass_config.enable_fusion
|
not config2.compilation_config.pass_config.enable_fusion
|
||||||
|
)
|
||||||
pass_manager3 = PostGradPassManager()
|
pass_manager3 = PostGradPassManager()
|
||||||
pass_manager3.configure(config2)
|
pass_manager3.configure(config2)
|
||||||
pass_manager3.add(callable)
|
pass_manager3.add(callable)
|
||||||
|
|||||||
@@ -12,14 +12,20 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
|
|||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (
|
||||||
PassConfig, VllmConfig)
|
CompilationConfig,
|
||||||
|
DeviceConfig,
|
||||||
|
ModelConfig,
|
||||||
|
PassConfig,
|
||||||
|
VllmConfig,
|
||||||
|
)
|
||||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
from vllm.distributed.parallel_state import (
|
||||||
initialize_model_parallel)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||||
Fp8LinearOp)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
@@ -36,16 +42,15 @@ prompts = [
|
|||||||
|
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||||
hidden_size=16,
|
):
|
||||||
intermediate_size=32,
|
|
||||||
vllm_config: VllmConfig = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.gate_proj = torch.nn.Parameter(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
torch.empty((intermediate_size, hidden_size)))
|
torch.empty((intermediate_size, hidden_size))
|
||||||
|
)
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
@@ -64,7 +69,7 @@ class TestModel(torch.nn.Module):
|
|||||||
# Reshape input
|
# Reshape input
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
|
|
||||||
#matrix multiplication
|
# matrix multiplication
|
||||||
permute = self.gate_proj.permute(1, 0)
|
permute = self.gate_proj.permute(1, 0)
|
||||||
mm = torch.mm(view, permute)
|
mm = torch.mm(view, permute)
|
||||||
|
|
||||||
@@ -82,7 +87,7 @@ class TestModel(torch.nn.Module):
|
|||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.reduce_scatter.default,
|
torch.ops.vllm.reduce_scatter.default,
|
||||||
torch.ops.vllm.all_gather.default
|
torch.ops.vllm.all_gather.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
def ops_in_model(self):
|
def ops_in_model(self):
|
||||||
@@ -90,18 +95,16 @@ class TestModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestQuantModel(torch.nn.Module):
|
class TestQuantModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||||
hidden_size=16,
|
):
|
||||||
intermediate_size=32,
|
|
||||||
vllm_config: VllmConfig = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
(intermediate_size, hidden_size)),
|
torch.empty((intermediate_size, hidden_size)), requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
@@ -111,8 +114,7 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
# Create a weight that is compatible with torch._scaled_mm,
|
# Create a weight that is compatible with torch._scaled_mm,
|
||||||
# which expects a column-major layout.
|
# which expects a column-major layout.
|
||||||
self.w = torch.rand(hidden_size,
|
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||||
intermediate_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
@@ -129,7 +131,7 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
# Reshape input
|
# Reshape input
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
|
|
||||||
#matrix multiplication
|
# matrix multiplication
|
||||||
permute = self.gate_proj.permute(1, 0)
|
permute = self.gate_proj.permute(1, 0)
|
||||||
mm = torch.mm(view, permute)
|
mm = torch.mm(view, permute)
|
||||||
|
|
||||||
@@ -140,45 +142,51 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
|
|
||||||
# scaled_mm with static input quantization
|
# scaled_mm with static input quantization
|
||||||
fp8_linear_result = self.fp8_linear.apply(norm_output,
|
fp8_linear_result = self.fp8_linear.apply(
|
||||||
|
norm_output,
|
||||||
self.w,
|
self.w,
|
||||||
self.wscale,
|
self.wscale,
|
||||||
input_scale=self.scale.to(
|
input_scale=self.scale.to(norm_output.device),
|
||||||
norm_output.device))
|
)
|
||||||
|
|
||||||
return fp8_linear_result, residual_output
|
return fp8_linear_result, residual_output
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
ops_to_remove = [torch.ops.vllm.all_reduce.default
|
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
|
||||||
] # Always removed by SP
|
|
||||||
# The following are only removed if fusion happens
|
# The following are only removed if fusion happens
|
||||||
if self.vllm_config and self.vllm_config.compilation_config \
|
if (
|
||||||
.pass_config.enable_fusion:
|
self.vllm_config
|
||||||
ops_to_remove.extend([
|
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||||
|
):
|
||||||
|
ops_to_remove.extend(
|
||||||
|
[
|
||||||
torch.ops._C.fused_add_rms_norm.default,
|
torch.ops._C.fused_add_rms_norm.default,
|
||||||
torch.ops._C.static_scaled_fp8_quant.default,
|
torch.ops._C.static_scaled_fp8_quant.default,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
return ops_to_remove
|
return ops_to_remove
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
ops_to_add = [
|
ops_to_add = [
|
||||||
torch.ops.vllm.reduce_scatter.default,
|
torch.ops.vllm.reduce_scatter.default,
|
||||||
torch.ops.vllm.all_gather.default
|
torch.ops.vllm.all_gather.default,
|
||||||
]
|
]
|
||||||
# The following is only added if fusion happens
|
# The following is only added if fusion happens
|
||||||
if self.vllm_config and self.vllm_config.compilation_config \
|
if (
|
||||||
.pass_config.enable_fusion:
|
self.vllm_config
|
||||||
ops_to_add.append(
|
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
|
):
|
||||||
|
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
|
||||||
return ops_to_add
|
return ops_to_add
|
||||||
|
|
||||||
def ops_in_model(self):
|
def ops_in_model(self):
|
||||||
if self.vllm_config and self.vllm_config.compilation_config \
|
if (
|
||||||
.pass_config.enable_fusion:
|
self.vllm_config
|
||||||
|
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||||
|
):
|
||||||
# If fusion happens, the fused op is the one
|
# If fusion happens, the fused op is the one
|
||||||
# we check for (de)functionalization
|
# we check for (de)functionalization
|
||||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
|
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] # noqa: E501
|
||||||
] # noqa: E501
|
|
||||||
else:
|
else:
|
||||||
# If no fusion, the original ops are checked
|
# If no fusion, the original ops are checked
|
||||||
return [
|
return [
|
||||||
@@ -195,30 +203,47 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
def test_sequence_parallelism_pass(
|
||||||
def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
|
test_model_cls: type[torch.nn.Module],
|
||||||
batch_size: int, seq_len: int,
|
batch_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype,
|
seq_len: int,
|
||||||
enable_fusion: bool):
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
enable_fusion: bool,
|
||||||
|
):
|
||||||
num_processes = 2
|
num_processes = 2
|
||||||
|
|
||||||
def run_torch_spawn(fn, nprocs):
|
def run_torch_spawn(fn, nprocs):
|
||||||
# need to use torch.mp.spawn otherwise will have problems with
|
# need to use torch.mp.spawn otherwise will have problems with
|
||||||
# torch.distributed and cuda
|
# torch.distributed and cuda
|
||||||
torch.multiprocessing.spawn(fn,
|
torch.multiprocessing.spawn(
|
||||||
args=(num_processes, test_model_cls,
|
fn,
|
||||||
batch_size, seq_len, hidden_size,
|
args=(
|
||||||
dtype, enable_fusion),
|
num_processes,
|
||||||
nprocs=nprocs)
|
test_model_cls,
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
hidden_size,
|
||||||
|
dtype,
|
||||||
|
enable_fusion,
|
||||||
|
),
|
||||||
|
nprocs=nprocs,
|
||||||
|
)
|
||||||
|
|
||||||
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
||||||
|
|
||||||
|
|
||||||
def sequence_parallelism_pass_on_test_model(
|
def sequence_parallelism_pass_on_test_model(
|
||||||
local_rank: int, world_size: int,
|
local_rank: int,
|
||||||
test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
|
world_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
|
test_model_cls: type[torch.nn.Module],
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
enable_fusion: bool,
|
||||||
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
@@ -226,13 +251,15 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# initialize distributed
|
# initialize distributed
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -240,27 +267,28 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
|
|
||||||
# configure vllm config for SequenceParallelismPass
|
# configure vllm config for SequenceParallelismPass
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
vllm_config.compilation_config = CompilationConfig(
|
||||||
|
pass_config=PassConfig(
|
||||||
enable_sequence_parallelism=True,
|
enable_sequence_parallelism=True,
|
||||||
enable_fusion=enable_fusion,
|
enable_fusion=enable_fusion,
|
||||||
enable_noop=True)) # NoOp needed for fusion
|
enable_noop=True,
|
||||||
|
)
|
||||||
|
) # NoOp needed for fusion
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
# in the vllm_config, it's not really used.
|
# in the vllm_config, it's not really used.
|
||||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||||
vllm_config.model_config = ModelConfig(model=model_name,
|
vllm_config.model_config = ModelConfig(
|
||||||
trust_remote_code=True,
|
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||||
dtype=dtype,
|
)
|
||||||
seed=42)
|
|
||||||
|
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
passes_for_backend: list[VllmInductorPass] = \
|
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
|
||||||
[noop_pass, sequence_parallelism_pass]
|
|
||||||
|
|
||||||
if enable_fusion:
|
if enable_fusion:
|
||||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||||
@@ -271,12 +299,9 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
backend_no_func = TestBackend(*passes_for_backend)
|
backend_no_func = TestBackend(*passes_for_backend)
|
||||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||||
|
|
||||||
model = test_model_cls(hidden_size,
|
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
|
||||||
hidden_size * 2,
|
|
||||||
vllm_config=vllm_config)
|
|
||||||
|
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
dtype=dtype)
|
|
||||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
|
|
||||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||||
@@ -297,8 +322,7 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
# check if the functionalization pass is applied
|
# check if the functionalization pass is applied
|
||||||
for op in model.ops_in_model():
|
for op in model.ops_in_model():
|
||||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
|
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
|
||||||
op) is None # noqa: E501
|
|
||||||
|
|
||||||
# make sure the ops were all de-functionalized
|
# make sure the ops were all de-functionalized
|
||||||
found = dict()
|
found = dict()
|
||||||
|
|||||||
@@ -8,10 +8,15 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.compilation.activation_quant_fusion import (
|
from vllm.compilation.activation_quant_fusion import (
|
||||||
FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
|
FUSED_OPS,
|
||||||
|
SILU_MUL_OP,
|
||||||
|
ActivationQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.compilation.fusion import QUANT_OPS
|
from vllm.compilation.fusion import QUANT_OPS
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
@@ -19,9 +24,14 @@ from vllm.compilation.post_cleanup import PostCleanupPass
|
|||||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
|
GroupShape,
|
||||||
|
kFp8StaticTensorSym,
|
||||||
|
kNvfp4Quant,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp, cutlass_fp8_supported)
|
Fp8LinearOp,
|
||||||
|
cutlass_fp8_supported,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import override_cutlass_fp8_supported
|
from ..utils import override_cutlass_fp8_supported
|
||||||
@@ -36,7 +46,6 @@ def is_nvfp4_supported():
|
|||||||
|
|
||||||
|
|
||||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
@@ -53,10 +62,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
x2 = self.fp8_linear.apply(y,
|
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||||
self.w,
|
|
||||||
self.wscale,
|
|
||||||
input_scale=self.wscale)
|
|
||||||
return x2
|
return x2
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -67,11 +73,12 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from vllm.compilation.activation_quant_fusion import (
|
from vllm.compilation.activation_quant_fusion import (
|
||||||
silu_and_mul_nvfp4_quant_supported)
|
silu_and_mul_nvfp4_quant_supported,
|
||||||
|
)
|
||||||
|
|
||||||
assert silu_and_mul_nvfp4_quant_supported
|
assert silu_and_mul_nvfp4_quant_supported
|
||||||
|
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
@@ -88,12 +95,14 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
||||||
out = cutlass_scaled_fp4_mm(a=y_quant,
|
out = cutlass_scaled_fp4_mm(
|
||||||
|
a=y_quant,
|
||||||
b=self.w,
|
b=self.w,
|
||||||
block_scale_a=y_block_scale,
|
block_scale_a=y_block_scale,
|
||||||
block_scale_b=self.w_block_scale,
|
block_scale_b=self.w_block_scale,
|
||||||
alpha=self.alpha,
|
alpha=self.alpha,
|
||||||
out_dtype=y.dtype)
|
out_dtype=y.dtype,
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -108,16 +117,24 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_class",
|
"model_class",
|
||||||
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
cast(
|
||||||
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]))
|
list[type],
|
||||||
|
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||||
|
if is_nvfp4_supported()
|
||||||
|
else [TestSiluMulFp8QuantModel],
|
||||||
|
),
|
||||||
|
)
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize("cuda_force_torch",
|
@pytest.mark.parametrize(
|
||||||
[True, False] if cutlass_fp8_supported() else [True])
|
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
)
|
||||||
reason="Only test on CUDA and ROCm")
|
@pytest.mark.skipif(
|
||||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
|
||||||
cuda_force_torch):
|
)
|
||||||
|
def test_fusion_silu_and_mul_quant(
|
||||||
|
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
|
||||||
|
):
|
||||||
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
||||||
pytest.skip("Duplicate tests for NVFP4")
|
pytest.skip("Duplicate tests for NVFP4")
|
||||||
|
|
||||||
@@ -129,17 +146,13 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
|||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
config = VllmConfig()
|
config = VllmConfig()
|
||||||
config.compilation_config = CompilationConfig(
|
config.compilation_config = CompilationConfig(
|
||||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
|
||||||
|
)
|
||||||
fusion_pass = ActivationQuantFusionPass(config)
|
fusion_pass = ActivationQuantFusionPass(config)
|
||||||
|
|
||||||
passes = [
|
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
|
||||||
NoOpEliminationPass(config), fusion_pass,
|
|
||||||
PostCleanupPass(config)
|
|
||||||
]
|
|
||||||
backend = TestBackend(*passes)
|
backend = TestBackend(*passes)
|
||||||
model = model_class(hidden_size=hidden_size,
|
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
|
||||||
cuda_force_torch=cuda_force_torch,
|
|
||||||
x=x)
|
|
||||||
|
|
||||||
# First dimension dynamic
|
# First dimension dynamic
|
||||||
torch._dynamo.mark_dynamic(x, 0)
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
@@ -155,10 +168,9 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
|||||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||||
atol, rtol = 1e-1, 1e-1
|
atol, rtol = 1e-1, 1e-1
|
||||||
|
|
||||||
torch.testing.assert_close(result[0].to(dtype=dtype),
|
torch.testing.assert_close(
|
||||||
result2[0].to(dtype=dtype),
|
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||||
atol=atol,
|
)
|
||||||
rtol=rtol)
|
|
||||||
|
|
||||||
assert fusion_pass.matched_count == 1
|
assert fusion_pass.matched_count == 1
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from vllm.config import CompilationLevel
|
|||||||
|
|
||||||
|
|
||||||
class MyMod(torch.nn.Module):
|
class MyMod(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
return x + cache
|
return x + cache
|
||||||
@@ -18,12 +17,12 @@ class MyMod(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
|
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||||
|
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
compiled_callable = torch.compile(self.forward, backend="eager")
|
compiled_callable = torch.compile(self.forward, backend="eager")
|
||||||
super().__init__(compiled_callable,
|
super().__init__(
|
||||||
compilation_level=CompilationLevel.DYNAMO_ONCE)
|
compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
# this is the function to be compiled
|
# this is the function to be compiled
|
||||||
@@ -54,10 +53,8 @@ def test_torch_compile_wrapper():
|
|||||||
|
|
||||||
# for new input, dispatch to the compiled code directly
|
# for new input, dispatch to the compiled code directly
|
||||||
new_x = torch.tensor([3])
|
new_x = torch.tensor([3])
|
||||||
assert wrapper(new_x,
|
assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code
|
||||||
None).item() == 6 # dispatch to the first compiled code
|
assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code
|
||||||
assert wrapper(
|
|
||||||
new_x, cache).item() == 5 # dispatch to the second compiled code
|
|
||||||
|
|
||||||
for wrapper in wrappers:
|
for wrapper in wrappers:
|
||||||
# make sure they have independent compiled codes
|
# make sure they have independent compiled codes
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def create_config():
|
def create_config():
|
||||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
engine_args = EngineArgs(
|
||||||
trust_remote_code=True)
|
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||||
|
)
|
||||||
return engine_args.create_engine_config()
|
return engine_args.create_engine_config()
|
||||||
|
|
||||||
# Create config with CUDA_VISIBLE_DEVICES set normally
|
# Create config with CUDA_VISIBLE_DEVICES set normally
|
||||||
@@ -34,16 +35,18 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
|||||||
empty_config_dict.pop("instance_id", None)
|
empty_config_dict.pop("instance_id", None)
|
||||||
|
|
||||||
assert deep_compare(normal_config_dict, empty_config_dict), (
|
assert deep_compare(normal_config_dict, empty_config_dict), (
|
||||||
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
|
'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""'
|
||||||
" should be equivalent")
|
" should be equivalent"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
||||||
# In testing, this method needs to be nested inside as ray does not
|
# In testing, this method needs to be nested inside as ray does not
|
||||||
# see the test module.
|
# see the test module.
|
||||||
def create_config():
|
def create_config():
|
||||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
engine_args = EngineArgs(
|
||||||
trust_remote_code=True)
|
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||||
|
)
|
||||||
return engine_args.create_engine_config()
|
return engine_args.create_engine_config()
|
||||||
|
|
||||||
config = create_config()
|
config = create_config()
|
||||||
@@ -51,6 +54,7 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
assert parallel_config.ray_runtime_env is None
|
assert parallel_config.ray_runtime_env is None
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
ray.init()
|
ray.init()
|
||||||
|
|
||||||
runtime_env = {
|
runtime_env = {
|
||||||
@@ -59,13 +63,13 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
config_ref = ray.remote(create_config).options(
|
config_ref = ray.remote(create_config).options(runtime_env=runtime_env).remote()
|
||||||
runtime_env=runtime_env).remote()
|
|
||||||
|
|
||||||
config = ray.get(config_ref)
|
config = ray.get(config_ref)
|
||||||
parallel_config = config.parallel_config
|
parallel_config = config.parallel_config
|
||||||
assert parallel_config.ray_runtime_env is not None
|
assert parallel_config.ray_runtime_env is not None
|
||||||
assert parallel_config.ray_runtime_env.env_vars().get(
|
assert (
|
||||||
"TEST_ENV_VAR") == "test_value"
|
parallel_config.ray_runtime_env.env_vars().get("TEST_ENV_VAR") == "test_value"
|
||||||
|
)
|
||||||
|
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|||||||
@@ -16,13 +16,13 @@ def test_mp_reducer(monkeypatch):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
|
# Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
|
||||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
# Ensure transformers_modules is not in sys.modules
|
# Ensure transformers_modules is not in sys.modules
|
||||||
if 'transformers_modules' in sys.modules:
|
if "transformers_modules" in sys.modules:
|
||||||
del sys.modules['transformers_modules']
|
del sys.modules["transformers_modules"]
|
||||||
|
|
||||||
with patch('multiprocessing.reducer.register') as mock_register:
|
with patch("multiprocessing.reducer.register") as mock_register:
|
||||||
engine_args = AsyncEngineArgs(
|
engine_args = AsyncEngineArgs(
|
||||||
model="facebook/opt-125m",
|
model="facebook/opt-125m",
|
||||||
max_model_len=32,
|
max_model_len=32,
|
||||||
@@ -36,7 +36,8 @@ def test_mp_reducer(monkeypatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert mock_register.called, (
|
assert mock_register.called, (
|
||||||
"multiprocessing.reducer.register should have been called")
|
"multiprocessing.reducer.register should have been called"
|
||||||
|
)
|
||||||
|
|
||||||
vllm_config_registered = False
|
vllm_config_registered = False
|
||||||
for call_args in mock_register.call_args_list:
|
for call_args in mock_register.call_args_list:
|
||||||
@@ -45,8 +46,7 @@ def test_mp_reducer(monkeypatch):
|
|||||||
vllm_config_registered = True
|
vllm_config_registered = True
|
||||||
|
|
||||||
reducer_func = call_args[0][1]
|
reducer_func = call_args[0][1]
|
||||||
assert callable(
|
assert callable(reducer_func), "Reducer function should be callable"
|
||||||
reducer_func), "Reducer function should be callable"
|
|
||||||
break
|
break
|
||||||
|
|
||||||
assert vllm_config_registered, (
|
assert vllm_config_registered, (
|
||||||
|
|||||||
@@ -30,22 +30,27 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (
|
||||||
BatchEncoding, BatchFeature)
|
AutoConfig,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BatchEncoding,
|
||||||
|
BatchFeature,
|
||||||
|
)
|
||||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||||
|
|
||||||
from tests.models.utils import (TokensTextLogprobs,
|
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
|
||||||
TokensTextLogprobsPromptLogprobs)
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.assets.audio import AudioAsset
|
from vllm.assets.audio import AudioAsset
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
from vllm.assets.video import VideoAsset
|
from vllm.assets.video import VideoAsset
|
||||||
from vllm.config.model import (ConvertOption, RunnerOption,
|
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
|
||||||
_get_and_verify_dtype)
|
|
||||||
from vllm.connections import global_http_connection
|
from vllm.connections import global_http_connection
|
||||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
from vllm.distributed import (
|
||||||
|
cleanup_dist_env_and_memory,
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel)
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
from vllm.multimodal.utils import fetch_image
|
from vllm.multimodal.utils import fetch_image
|
||||||
@@ -82,12 +87,13 @@ class ImageAssetPrompts(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class ImageTestAssets(list[ImageAsset]):
|
class ImageTestAssets(list[ImageAsset]):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__([
|
super().__init__(
|
||||||
|
[
|
||||||
ImageAsset("stop_sign"),
|
ImageAsset("stop_sign"),
|
||||||
ImageAsset("cherry_blossom"),
|
ImageAsset("cherry_blossom"),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
|
def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@@ -104,11 +110,12 @@ class VideoAssetPrompts(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class VideoTestAssets(list[VideoAsset]):
|
class VideoTestAssets(list[VideoAsset]):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__([
|
super().__init__(
|
||||||
|
[
|
||||||
VideoAsset("baby_reading"),
|
VideoAsset("baby_reading"),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
|
def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
|
||||||
return [prompts["baby_reading"]]
|
return [prompts["baby_reading"]]
|
||||||
@@ -120,12 +127,13 @@ class AudioAssetPrompts(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class AudioTestAssets(list[AudioAsset]):
|
class AudioTestAssets(list[AudioAsset]):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__([
|
super().__init__(
|
||||||
|
[
|
||||||
AudioAsset("mary_had_lamb"),
|
AudioAsset("mary_had_lamb"),
|
||||||
AudioAsset("winning_call"),
|
AudioAsset("winning_call"),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
|
def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
|
||||||
return [prompts["mary_had_lamb"], prompts["winning_call"]]
|
return [prompts["mary_had_lamb"], prompts["winning_call"]]
|
||||||
@@ -220,6 +228,7 @@ def example_system_message() -> str:
|
|||||||
|
|
||||||
class DecoderPromptType(Enum):
|
class DecoderPromptType(Enum):
|
||||||
"""For encoder/decoder models only."""
|
"""For encoder/decoder models only."""
|
||||||
|
|
||||||
CUSTOM = 1
|
CUSTOM = 1
|
||||||
NONE = 2
|
NONE = 2
|
||||||
EMPTY_STR = 3
|
EMPTY_STR = 3
|
||||||
@@ -253,15 +262,13 @@ _R = TypeVar("_R")
|
|||||||
|
|
||||||
|
|
||||||
class HfRunner:
|
class HfRunner:
|
||||||
|
|
||||||
def get_default_device(self):
|
def get_default_device(self):
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
return ("cpu"
|
return "cpu" if current_platform.is_cpu() else current_platform.device_type
|
||||||
if current_platform.is_cpu() else current_platform.device_type)
|
|
||||||
|
|
||||||
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
|
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
|
||||||
if x is None or isinstance(x, (bool, )):
|
if x is None or isinstance(x, (bool,)):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
@@ -289,8 +296,11 @@ class HfRunner:
|
|||||||
# Set this to avoid hanging issue
|
# Set this to avoid hanging issue
|
||||||
default_torch_num_threads: Optional[int] = None,
|
default_torch_num_threads: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
init_ctx = (
|
||||||
set_default_torch_num_threads(default_torch_num_threads))
|
nullcontext()
|
||||||
|
if default_torch_num_threads is None
|
||||||
|
else set_default_torch_num_threads(default_torch_num_threads)
|
||||||
|
)
|
||||||
|
|
||||||
with init_ctx:
|
with init_ctx:
|
||||||
self._init(
|
self._init(
|
||||||
@@ -362,14 +372,15 @@ class HfRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# in case some unquantized custom models are not in same dtype
|
# in case some unquantized custom models are not in same dtype
|
||||||
if (getattr(model, "quantization_method", None) is None
|
if getattr(model, "quantization_method", None) is None and any(
|
||||||
and any(p.dtype != self.dtype
|
p.dtype != self.dtype for p in model.parameters()
|
||||||
for p in model.parameters())):
|
):
|
||||||
model = model.to(dtype=self.dtype)
|
model = model.to(dtype=self.dtype)
|
||||||
|
|
||||||
if (getattr(model, "quantization_method", None) != "bitsandbytes"
|
if (
|
||||||
and len({p.device
|
getattr(model, "quantization_method", None) != "bitsandbytes"
|
||||||
for p in model.parameters()}) < 2):
|
and len({p.device for p in model.parameters()}) < 2
|
||||||
|
):
|
||||||
model = model.to(device=self.device)
|
model = model.to(device=self.device)
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -384,6 +395,7 @@ class HfRunner:
|
|||||||
# don't put this import at the top level
|
# don't put this import at the top level
|
||||||
# it will call torch.cuda.device_count()
|
# it will call torch.cuda.device_count()
|
||||||
from transformers import AutoProcessor # noqa: F401
|
from transformers import AutoProcessor # noqa: F401
|
||||||
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -471,10 +483,9 @@ class HfRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(
|
||||||
images=images,
|
prompts, images=images, videos=videos, audios=audios
|
||||||
videos=videos,
|
)
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
@@ -501,16 +512,17 @@ class HfRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[int], str]]:
|
) -> list[tuple[list[int], str]]:
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(
|
||||||
|
prompts,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
audios=audios,
|
audios=audios,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
return [(output_ids[0], output_str[0])
|
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
|
||||||
for output_ids, output_str in outputs]
|
|
||||||
|
|
||||||
def generate_beam_search(
|
def generate_beam_search(
|
||||||
self,
|
self,
|
||||||
@@ -521,21 +533,22 @@ class HfRunner:
|
|||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(
|
||||||
|
prompts,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
num_beams=beam_width,
|
num_beams=beam_width,
|
||||||
num_return_sequences=beam_width,
|
num_return_sequences=beam_width,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
audios=audios)
|
audios=audios,
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(len(outputs)):
|
for i in range(len(outputs)):
|
||||||
output_ids, output_str = outputs[i]
|
output_ids, output_str = outputs[i]
|
||||||
for j in range(len(output_ids)):
|
for j in range(len(output_ids)):
|
||||||
output_ids[j] = [
|
output_ids[j] = [
|
||||||
x for x in output_ids[j]
|
x for x in output_ids[j] if x != self.tokenizer.pad_token_id
|
||||||
if x != self.tokenizer.pad_token_id
|
|
||||||
]
|
]
|
||||||
outputs[i] = (output_ids, output_str)
|
outputs[i] = (output_ids, output_str)
|
||||||
return outputs
|
return outputs
|
||||||
@@ -549,10 +562,9 @@ class HfRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[torch.Tensor]]:
|
) -> list[list[torch.Tensor]]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(
|
||||||
images=images,
|
prompts, images=images, videos=videos, audios=audios
|
||||||
videos=videos,
|
)
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
all_logprobs: list[list[torch.Tensor]] = []
|
all_logprobs: list[list[torch.Tensor]] = []
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
@@ -565,8 +577,7 @@ class HfRunner:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
seq_logprobs = self._hidden_states_to_seq_logprobs(
|
seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states)
|
||||||
output.hidden_states)
|
|
||||||
all_logprobs.append(seq_logprobs)
|
all_logprobs.append(seq_logprobs)
|
||||||
return all_logprobs
|
return all_logprobs
|
||||||
|
|
||||||
@@ -630,10 +641,9 @@ class HfRunner:
|
|||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[TokensTextLogprobs]:
|
) -> list[TokensTextLogprobs]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(
|
||||||
images=images,
|
prompts, images=images, videos=videos, audios=audios
|
||||||
videos=videos,
|
)
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
all_logprobs: list[list[dict[int, float]]] = []
|
all_logprobs: list[list[dict[int, float]]] = []
|
||||||
all_output_ids: list[list[int]] = []
|
all_output_ids: list[list[int]] = []
|
||||||
@@ -653,8 +663,7 @@ class HfRunner:
|
|||||||
(
|
(
|
||||||
seq_logprobs_lst,
|
seq_logprobs_lst,
|
||||||
output_len,
|
output_len,
|
||||||
) = self._hidden_states_to_logprobs(output.hidden_states,
|
) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)
|
||||||
num_logprobs)
|
|
||||||
|
|
||||||
all_logprobs.append(seq_logprobs_lst)
|
all_logprobs.append(seq_logprobs_lst)
|
||||||
seq_ids = output.sequences[0]
|
seq_ids = output.sequences[0]
|
||||||
@@ -664,19 +673,16 @@ class HfRunner:
|
|||||||
all_output_strs.append(self.tokenizer.decode(output_ids))
|
all_output_strs.append(self.tokenizer.decode(output_ids))
|
||||||
|
|
||||||
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
|
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
|
||||||
return [(output_ids, output_str, output_logprobs)
|
return [
|
||||||
for output_ids, output_str, output_logprobs in outputs]
|
(output_ids, output_str, output_logprobs)
|
||||||
|
for output_ids, output_str, output_logprobs in outputs
|
||||||
|
]
|
||||||
|
|
||||||
def encode(self, prompts: list[str], *args,
|
def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
|
||||||
**kwargs) -> list[list[torch.Tensor]]:
|
|
||||||
return self.model.encode(prompts, *args, **kwargs)
|
return self.model.encode(prompts, *args, **kwargs)
|
||||||
|
|
||||||
def predict(self, prompts: list[list[str]], *args,
|
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
|
||||||
**kwargs) -> torch.Tensor:
|
return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)
|
||||||
return self.model.predict(prompts,
|
|
||||||
*args,
|
|
||||||
convert_to_tensor=True,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
@@ -727,8 +733,11 @@ class VllmRunner:
|
|||||||
default_torch_num_threads: Optional[int] = None,
|
default_torch_num_threads: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
init_ctx = (
|
||||||
set_default_torch_num_threads(default_torch_num_threads))
|
nullcontext()
|
||||||
|
if default_torch_num_threads is None
|
||||||
|
else set_default_torch_num_threads(default_torch_num_threads)
|
||||||
|
)
|
||||||
|
|
||||||
if not kwargs.get("compilation_config", None):
|
if not kwargs.get("compilation_config", None):
|
||||||
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
|
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
|
||||||
@@ -760,11 +769,12 @@ class VllmRunner:
|
|||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
if any(x is not None and len(x) != len(prompts)
|
if any(
|
||||||
for x in [images, videos, audios]):
|
x is not None and len(x) != len(prompts) for x in [images, videos, audios]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"All non-None multimodal inputs must have the same length as "
|
"All non-None multimodal inputs must have the same length as prompts"
|
||||||
"prompts")
|
)
|
||||||
|
|
||||||
inputs = list[dict[str, Any]]()
|
inputs = list[dict[str, Any]]()
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
@@ -800,14 +810,11 @@ class VllmRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
inputs = self.get_inputs(prompts,
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
images=images,
|
|
||||||
videos=videos,
|
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
req_outputs = self.llm.generate(inputs,
|
req_outputs = self.llm.generate(
|
||||||
sampling_params=sampling_params,
|
inputs, sampling_params=sampling_params, **kwargs
|
||||||
**kwargs)
|
)
|
||||||
|
|
||||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||||
for req_output in req_outputs:
|
for req_output in req_outputs:
|
||||||
@@ -834,8 +841,9 @@ class VllmRunner:
|
|||||||
output_str = sample.text
|
output_str = sample.text
|
||||||
output_ids = list(sample.token_ids)
|
output_ids = list(sample.token_ids)
|
||||||
output_logprobs = sample.logprobs
|
output_logprobs = sample.logprobs
|
||||||
outputs.append((output_ids, output_str, output_logprobs,
|
outputs.append(
|
||||||
req_output.prompt_logprobs))
|
(output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
|
||||||
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def generate_w_logprobs(
|
def generate_w_logprobs(
|
||||||
@@ -846,23 +854,22 @@ class VllmRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[list[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
list[TokensTextLogprobsPromptLogprobs]]:
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
inputs = self.get_inputs(prompts,
|
|
||||||
images=images,
|
|
||||||
videos=videos,
|
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
req_outputs = self.llm.generate(inputs,
|
req_outputs = self.llm.generate(
|
||||||
sampling_params=sampling_params,
|
inputs, sampling_params=sampling_params, **kwargs
|
||||||
**kwargs)
|
)
|
||||||
|
|
||||||
toks_str_logsprobs_prompt_logprobs = (
|
toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
|
||||||
self._final_steps_generate_w_logprobs(req_outputs))
|
req_outputs
|
||||||
|
)
|
||||||
# Omit prompt logprobs if not required by sampling params
|
# Omit prompt logprobs if not required by sampling params
|
||||||
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
return (
|
||||||
if sampling_params.prompt_logprobs is None else
|
[x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||||
toks_str_logsprobs_prompt_logprobs)
|
if sampling_params.prompt_logprobs is None
|
||||||
|
else toks_str_logsprobs_prompt_logprobs
|
||||||
|
)
|
||||||
|
|
||||||
def generate_greedy(
|
def generate_greedy(
|
||||||
self,
|
self,
|
||||||
@@ -874,14 +881,15 @@ class VllmRunner:
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[int], str]]:
|
) -> list[tuple[list[int], str]]:
|
||||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(
|
||||||
|
prompts,
|
||||||
greedy_params,
|
greedy_params,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
audios=audios,
|
audios=audios,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
return [(output_ids[0], output_str[0])
|
)
|
||||||
for output_ids, output_str in outputs]
|
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
|
||||||
|
|
||||||
def generate_greedy_logprobs(
|
def generate_greedy_logprobs(
|
||||||
self,
|
self,
|
||||||
@@ -895,22 +903,24 @@ class VllmRunner:
|
|||||||
stop_token_ids: Optional[list[int]] = None,
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[list[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
list[TokensTextLogprobsPromptLogprobs]]:
|
|
||||||
greedy_logprobs_params = SamplingParams(
|
greedy_logprobs_params = SamplingParams(
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
logprobs=num_logprobs,
|
logprobs=num_logprobs,
|
||||||
prompt_logprobs=num_prompt_logprobs,
|
prompt_logprobs=num_prompt_logprobs,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
stop=stop)
|
stop=stop,
|
||||||
|
)
|
||||||
|
|
||||||
return self.generate_w_logprobs(prompts,
|
return self.generate_w_logprobs(
|
||||||
|
prompts,
|
||||||
greedy_logprobs_params,
|
greedy_logprobs_params,
|
||||||
images=images,
|
images=images,
|
||||||
audios=audios,
|
audios=audios,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
|
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
|
||||||
"""
|
"""
|
||||||
@@ -919,10 +929,9 @@ class VllmRunner:
|
|||||||
:param prompts: list of prompts to score
|
:param prompts: list of prompts to score
|
||||||
:return: perplexity score of each prompt
|
:return: perplexity score of each prompt
|
||||||
"""
|
"""
|
||||||
outputs = self.generate_greedy_logprobs(prompts,
|
outputs = self.generate_greedy_logprobs(
|
||||||
max_tokens=1,
|
prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
|
||||||
num_logprobs=None,
|
)
|
||||||
num_prompt_logprobs=0)
|
|
||||||
|
|
||||||
perplexities = []
|
perplexities = []
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
@@ -951,15 +960,13 @@ class VllmRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
concurrency_limit: Optional[int] = None,
|
concurrency_limit: Optional[int] = None,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
inputs = self.get_inputs(prompts,
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
images=images,
|
|
||||||
videos=videos,
|
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
outputs = self.llm.beam_search(inputs,
|
outputs = self.llm.beam_search(
|
||||||
BeamSearchParams(beam_width=beam_width,
|
inputs,
|
||||||
max_tokens=max_tokens),
|
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
|
||||||
concurrency_limit=concurrency_limit)
|
concurrency_limit=concurrency_limit,
|
||||||
|
)
|
||||||
returned_outputs = []
|
returned_outputs = []
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
token_ids = [x.tokens for x in output.sequences]
|
token_ids = [x.tokens for x in output.sequences]
|
||||||
@@ -971,17 +978,16 @@ class VllmRunner:
|
|||||||
req_outputs = self.llm.classify(prompts)
|
req_outputs = self.llm.classify(prompts)
|
||||||
return [req_output.outputs.probs for req_output in req_outputs]
|
return [req_output.outputs.probs for req_output in req_outputs]
|
||||||
|
|
||||||
def embed(self,
|
def embed(
|
||||||
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs) -> list[list[float]]:
|
**kwargs,
|
||||||
inputs = self.get_inputs(prompts,
|
) -> list[list[float]]:
|
||||||
images=images,
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
videos=videos,
|
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
req_outputs = self.llm.embed(inputs, *args, **kwargs)
|
req_outputs = self.llm.embed(inputs, *args, **kwargs)
|
||||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||||
@@ -1026,6 +1032,7 @@ def vllm_runner():
|
|||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def temporary_enable_log_propagate():
|
def temporary_enable_log_propagate():
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger("vllm")
|
logger = logging.getLogger("vllm")
|
||||||
logger.propagate = True
|
logger.propagate = True
|
||||||
yield
|
yield
|
||||||
@@ -1045,6 +1052,7 @@ def num_gpus_available():
|
|||||||
in current process."""
|
in current process."""
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
return current_platform.device_count()
|
return current_platform.device_count()
|
||||||
|
|
||||||
|
|
||||||
@@ -1058,12 +1066,11 @@ _dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
|
|||||||
def dummy_opt_path():
|
def dummy_opt_path():
|
||||||
json_path = os.path.join(_dummy_opt_path, "config.json")
|
json_path = os.path.join(_dummy_opt_path, "config.json")
|
||||||
if not os.path.exists(_dummy_opt_path):
|
if not os.path.exists(_dummy_opt_path):
|
||||||
snapshot_download(repo_id="facebook/opt-125m",
|
snapshot_download(
|
||||||
|
repo_id="facebook/opt-125m",
|
||||||
local_dir=_dummy_opt_path,
|
local_dir=_dummy_opt_path,
|
||||||
ignore_patterns=[
|
ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
)
|
||||||
"*.msgpack"
|
|
||||||
])
|
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
@@ -1077,12 +1084,18 @@ def dummy_opt_path():
|
|||||||
def dummy_llava_path():
|
def dummy_llava_path():
|
||||||
json_path = os.path.join(_dummy_llava_path, "config.json")
|
json_path = os.path.join(_dummy_llava_path, "config.json")
|
||||||
if not os.path.exists(_dummy_llava_path):
|
if not os.path.exists(_dummy_llava_path):
|
||||||
snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
|
snapshot_download(
|
||||||
|
repo_id="llava-hf/llava-1.5-7b-hf",
|
||||||
local_dir=_dummy_llava_path,
|
local_dir=_dummy_llava_path,
|
||||||
ignore_patterns=[
|
ignore_patterns=[
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
"*.bin",
|
||||||
"*.msgpack", "*.safetensors"
|
"*.bin.index.json",
|
||||||
])
|
"*.pt",
|
||||||
|
"*.h5",
|
||||||
|
"*.msgpack",
|
||||||
|
"*.safetensors",
|
||||||
|
],
|
||||||
|
)
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
@@ -1096,12 +1109,18 @@ def dummy_llava_path():
|
|||||||
def dummy_gemma2_embedding_path():
|
def dummy_gemma2_embedding_path():
|
||||||
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
|
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
|
||||||
if not os.path.exists(_dummy_gemma2_embedding_path):
|
if not os.path.exists(_dummy_gemma2_embedding_path):
|
||||||
snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
|
snapshot_download(
|
||||||
|
repo_id="BAAI/bge-multilingual-gemma2",
|
||||||
local_dir=_dummy_gemma2_embedding_path,
|
local_dir=_dummy_gemma2_embedding_path,
|
||||||
ignore_patterns=[
|
ignore_patterns=[
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
"*.bin",
|
||||||
"*.msgpack", "*.safetensors"
|
"*.bin.index.json",
|
||||||
])
|
"*.pt",
|
||||||
|
"*.h5",
|
||||||
|
"*.msgpack",
|
||||||
|
"*.safetensors",
|
||||||
|
],
|
||||||
|
)
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
@@ -1114,10 +1133,9 @@ def dummy_gemma2_embedding_path():
|
|||||||
# Add the flag `--optional` to allow run tests
|
# Add the flag `--optional` to allow run tests
|
||||||
# that are marked with @pytest.mark.optional
|
# that are marked with @pytest.mark.optional
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
parser.addoption("--optional",
|
parser.addoption(
|
||||||
action="store_true",
|
"--optional", action="store_true", default=False, help="run optional test"
|
||||||
default=False,
|
)
|
||||||
help="run optional test")
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(config, items):
|
def pytest_collection_modifyitems(config, items):
|
||||||
@@ -1185,7 +1203,6 @@ def _find_free_port() -> int:
|
|||||||
|
|
||||||
|
|
||||||
class LocalAssetServer:
|
class LocalAssetServer:
|
||||||
|
|
||||||
address: str
|
address: str
|
||||||
port: int
|
port: int
|
||||||
server: Optional[http.server.ThreadingHTTPServer]
|
server: Optional[http.server.ThreadingHTTPServer]
|
||||||
@@ -1200,9 +1217,9 @@ class LocalAssetServer:
|
|||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.port = _find_free_port()
|
self.port = _find_free_port()
|
||||||
self.server = http.server.ThreadingHTTPServer(
|
self.server = http.server.ThreadingHTTPServer(
|
||||||
(self.address, self.port), AssetHandler)
|
(self.address, self.port), AssetHandler
|
||||||
self.thread = threading.Thread(target=self.server.serve_forever,
|
)
|
||||||
daemon=True)
|
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from vllm.platforms import current_platform
|
|||||||
def check_cuda_context():
|
def check_cuda_context():
|
||||||
"""Check CUDA driver context status"""
|
"""Check CUDA driver context status"""
|
||||||
try:
|
try:
|
||||||
cuda = ctypes.CDLL('libcuda.so')
|
cuda = ctypes.CDLL("libcuda.so")
|
||||||
device = ctypes.c_int()
|
device = ctypes.c_int()
|
||||||
result = cuda.cuCtxGetDevice(ctypes.byref(device))
|
result = cuda.cuCtxGetDevice(ctypes.byref(device))
|
||||||
return (True, device.value) if result == 0 else (False, None)
|
return (True, device.value) if result == 0 else (False, None)
|
||||||
@@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
|||||||
# New thread should have no CUDA context initially
|
# New thread should have no CUDA context initially
|
||||||
valid_before, device_before = check_cuda_context()
|
valid_before, device_before = check_cuda_context()
|
||||||
if valid_before:
|
if valid_before:
|
||||||
return False, \
|
return (
|
||||||
"CUDA context should not exist in new thread, " \
|
False,
|
||||||
f"got device {device_before}"
|
"CUDA context should not exist in new thread, "
|
||||||
|
f"got device {device_before}",
|
||||||
|
)
|
||||||
|
|
||||||
# Test setting CUDA context
|
# Test setting CUDA context
|
||||||
current_platform.set_device(device_input)
|
current_platform.set_device(device_input)
|
||||||
@@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
|||||||
if not valid_after:
|
if not valid_after:
|
||||||
return False, "CUDA context should be valid after set_cuda_context"
|
return False, "CUDA context should be valid after set_cuda_context"
|
||||||
if device_id != expected_device_id:
|
if device_id != expected_device_id:
|
||||||
return False, \
|
return False, f"Expected device {expected_device_id}, got {device_id}"
|
||||||
f"Expected device {expected_device_id}, got {device_id}"
|
|
||||||
|
|
||||||
return True, "Success"
|
return True, "Success"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
|||||||
class TestSetCudaContext:
|
class TestSetCudaContext:
|
||||||
"""Test suite for the set_cuda_context function."""
|
"""Test suite for the set_cuda_context function."""
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||||
reason="CUDA not available")
|
@pytest.mark.parametrize(
|
||||||
@pytest.mark.parametrize(argnames="device_input,expected_device_id",
|
argnames="device_input,expected_device_id",
|
||||||
argvalues=[
|
argvalues=[
|
||||||
(0, 0),
|
(0, 0),
|
||||||
(torch.device('cuda:0'), 0),
|
(torch.device("cuda:0"), 0),
|
||||||
('cuda:0', 0),
|
("cuda:0", 0),
|
||||||
],
|
],
|
||||||
ids=["int", "torch_device", "string"])
|
ids=["int", "torch_device", "string"],
|
||||||
def test_set_cuda_context_parametrized(self, device_input,
|
)
|
||||||
expected_device_id):
|
def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
|
||||||
"""Test setting CUDA context in isolated threads."""
|
"""Test setting CUDA context in isolated threads."""
|
||||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
future = executor.submit(run_cuda_test_in_thread, device_input,
|
future = executor.submit(
|
||||||
expected_device_id)
|
run_cuda_test_in_thread, device_input, expected_device_id
|
||||||
|
)
|
||||||
success, message = future.result(timeout=30)
|
success, message = future.result(timeout=30)
|
||||||
assert success, message
|
assert success, message
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||||
reason="CUDA not available")
|
|
||||||
def test_set_cuda_context_invalid_device_type(self):
|
def test_set_cuda_context_invalid_device_type(self):
|
||||||
"""Test error handling for invalid device type."""
|
"""Test error handling for invalid device type."""
|
||||||
with pytest.raises(ValueError, match="Expected a cuda device"):
|
with pytest.raises(ValueError, match="Expected a cuda device"):
|
||||||
current_platform.set_device(torch.device('cpu'))
|
current_platform.set_device(torch.device("cpu"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -17,20 +17,16 @@ def test_computed_prefix_blocks(model: str):
|
|||||||
prompt = (
|
prompt = (
|
||||||
"You are a helpful assistant. How do I build a car from cardboard and "
|
"You are a helpful assistant. How do I build a car from cardboard and "
|
||||||
"paper clips? Is there an easy to follow video tutorial available "
|
"paper clips? Is there an easy to follow video tutorial available "
|
||||||
"online for free?")
|
"online for free?"
|
||||||
|
)
|
||||||
|
|
||||||
llm = LLM(model=model)
|
llm = LLM(model=model)
|
||||||
sampling_params = SamplingParams(max_tokens=10,
|
sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
|
||||||
temperature=0.0,
|
|
||||||
detokenize=False)
|
|
||||||
|
|
||||||
outputs_no_detokenization = llm.generate(prompt,
|
outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||||
sampling_params)[0].outputs[0]
|
|
||||||
sampling_params.detokenize = True
|
sampling_params.detokenize = True
|
||||||
outputs_with_detokenization = llm.generate(prompt,
|
outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||||
sampling_params)[0].outputs[0]
|
|
||||||
|
|
||||||
assert outputs_no_detokenization.text == ''
|
assert outputs_no_detokenization.text == ""
|
||||||
assert outputs_with_detokenization.text != ''
|
assert outputs_with_detokenization.text != ""
|
||||||
assert outputs_no_detokenization.token_ids == \
|
assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids
|
||||||
outputs_with_detokenization.token_ids
|
|
||||||
|
|||||||
@@ -8,15 +8,17 @@ from vllm import SamplingParams
|
|||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
|
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
|
||||||
|
|
||||||
PROMPT = "Hello, my name is Lee, and I'm a student in the " + \
|
PROMPT = "Hello, my name is Lee, and I'm a student in the " + "college of engineering"
|
||||||
"college of engineering"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_tokens,stop,truth", [
|
@pytest.mark.parametrize(
|
||||||
|
"min_tokens,stop,truth",
|
||||||
|
[
|
||||||
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
||||||
(0, "e", " is L"),
|
(0, "e", " is L"),
|
||||||
(5, "e", " is Lee, and I'm a stud"),
|
(5, "e", " is Lee, and I'm a stud"),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
||||||
"""Test for a specific min_tokens and stop.
|
"""Test for a specific min_tokens and stop.
|
||||||
|
|
||||||
@@ -31,7 +33,8 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
|||||||
stop=stop,
|
stop=stop,
|
||||||
min_tokens=min_tokens,
|
min_tokens=min_tokens,
|
||||||
)
|
)
|
||||||
request = EngineCoreRequest(request_id="",
|
request = EngineCoreRequest(
|
||||||
|
request_id="",
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
sampling_params=params,
|
sampling_params=params,
|
||||||
@@ -40,7 +43,8 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
|||||||
arrival_time=0.0,
|
arrival_time=0.0,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
data_parallel_rank=None)
|
data_parallel_rank=None,
|
||||||
|
)
|
||||||
|
|
||||||
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
|
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
|
||||||
|
|
||||||
|
|||||||
@@ -31,34 +31,39 @@ def test_stop_reason(vllm_model, example_prompts):
|
|||||||
llm = vllm_model.llm
|
llm = vllm_model.llm
|
||||||
|
|
||||||
# test stop token
|
# test stop token
|
||||||
outputs = llm.generate(example_prompts,
|
outputs = llm.generate(
|
||||||
|
example_prompts,
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
seed=SEED,
|
seed=SEED,
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
stop_token_ids=[stop_token_id]))
|
stop_token_ids=[stop_token_id],
|
||||||
|
),
|
||||||
|
)
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
output = output.outputs[0]
|
output = output.outputs[0]
|
||||||
assert output.finish_reason == "stop"
|
assert output.finish_reason == "stop"
|
||||||
assert output.stop_reason == stop_token_id
|
assert output.stop_reason == stop_token_id
|
||||||
|
|
||||||
# test stop string
|
# test stop string
|
||||||
outputs = llm.generate(example_prompts,
|
outputs = llm.generate(
|
||||||
|
example_prompts,
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
ignore_eos=True,
|
ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="."
|
||||||
seed=SEED,
|
),
|
||||||
max_tokens=MAX_TOKENS,
|
)
|
||||||
stop="."))
|
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
output = output.outputs[0]
|
output = output.outputs[0]
|
||||||
assert output.finish_reason == "stop"
|
assert output.finish_reason == "stop"
|
||||||
assert output.stop_reason == STOP_STR
|
assert output.stop_reason == STOP_STR
|
||||||
|
|
||||||
# test EOS token
|
# test EOS token
|
||||||
outputs = llm.generate(example_prompts,
|
outputs = llm.generate(
|
||||||
sampling_params=SamplingParams(
|
example_prompts,
|
||||||
seed=SEED, max_tokens=MAX_TOKENS))
|
sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS),
|
||||||
|
)
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
output = output.outputs[0]
|
output = output.outputs[0]
|
||||||
assert output.finish_reason == "length" or (
|
assert output.finish_reason == "length" or (
|
||||||
output.finish_reason == "stop" and output.stop_reason is None)
|
output.finish_reason == "stop" and output.stop_reason is None
|
||||||
|
)
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ def include_stop_str_in_output(request):
|
|||||||
|
|
||||||
|
|
||||||
class _DummyDetokenizer(BaseIncrementalDetokenizer):
|
class _DummyDetokenizer(BaseIncrementalDetokenizer):
|
||||||
|
|
||||||
def __init__(self, request: EngineCoreRequest):
|
def __init__(self, request: EngineCoreRequest):
|
||||||
super().__init__(request)
|
super().__init__(request)
|
||||||
|
|
||||||
@@ -27,7 +26,8 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
|||||||
params = SamplingParams(
|
params = SamplingParams(
|
||||||
stop=stop,
|
stop=stop,
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
min_tokens=min_tokens)
|
min_tokens=min_tokens,
|
||||||
|
)
|
||||||
# Keep other fields minimal for unit test purposes.
|
# Keep other fields minimal for unit test purposes.
|
||||||
req = EngineCoreRequest(
|
req = EngineCoreRequest(
|
||||||
request_id="test",
|
request_id="test",
|
||||||
@@ -44,8 +44,7 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
|||||||
return req
|
return req
|
||||||
|
|
||||||
|
|
||||||
def test_stop_string_while_stop_token_terminates(
|
def test_stop_string_while_stop_token_terminates(include_stop_str_in_output: bool):
|
||||||
include_stop_str_in_output: bool):
|
|
||||||
"""
|
"""
|
||||||
This test verifies that the detokenizer correctly handles the case where
|
This test verifies that the detokenizer correctly handles the case where
|
||||||
the generated token sequence contains both:
|
the generated token sequence contains both:
|
||||||
@@ -78,8 +77,9 @@ def test_stop_string_while_stop_token_terminates(
|
|||||||
token_ids = [ord(c) for c in generated_text]
|
token_ids = [ord(c) for c in generated_text]
|
||||||
|
|
||||||
# Create a request with the stop string and initialize the detokenizer.
|
# Create a request with the stop string and initialize the detokenizer.
|
||||||
req = _make_request(stop=[stop_string],
|
req = _make_request(
|
||||||
include_stop_str_in_output=include_stop_str_in_output)
|
stop=[stop_string], include_stop_str_in_output=include_stop_str_in_output
|
||||||
|
)
|
||||||
detok = _DummyDetokenizer(req)
|
detok = _DummyDetokenizer(req)
|
||||||
|
|
||||||
# Simulate that the last token ('Z') is a stop token (stop_terminated=True).
|
# Simulate that the last token ('Z') is a stop token (stop_terminated=True).
|
||||||
@@ -99,5 +99,4 @@ def test_stop_string_while_stop_token_terminates(
|
|||||||
|
|
||||||
# get_next_output_text should return the full text when finished=True.
|
# get_next_output_text should return the full text when finished=True.
|
||||||
# (Buffering only applies during streaming when finished=False.)
|
# (Buffering only applies during streaming when finished=False.)
|
||||||
assert detok.get_next_output_text(finished=True,
|
assert detok.get_next_output_text(finished=True, delta=False) == expected_text
|
||||||
delta=False) == expected_text
|
|
||||||
|
|||||||
@@ -11,12 +11,14 @@ MODEL = "meta-llama/llama-2-7b-hf"
|
|||||||
MAX_TOKENS = 200
|
MAX_TOKENS = 200
|
||||||
|
|
||||||
|
|
||||||
def _test_stopping(llm: LLM,
|
def _test_stopping(
|
||||||
|
llm: LLM,
|
||||||
expected_output: str,
|
expected_output: str,
|
||||||
expected_reason: Any,
|
expected_reason: Any,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stop_token_ids: Optional[list[int]] = None,
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
include_in_output: bool = False) -> None:
|
include_in_output: bool = False,
|
||||||
|
) -> None:
|
||||||
output = llm.generate(
|
output = llm.generate(
|
||||||
"A story about vLLM:\n",
|
"A story about vLLM:\n",
|
||||||
SamplingParams(
|
SamplingParams(
|
||||||
@@ -25,7 +27,8 @@ def _test_stopping(llm: LLM,
|
|||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
include_stop_str_in_output=include_in_output,
|
include_stop_str_in_output=include_in_output,
|
||||||
))[0].outputs[0]
|
),
|
||||||
|
)[0].outputs[0]
|
||||||
|
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert output.text == expected_output
|
assert output.text == expected_output
|
||||||
@@ -33,17 +36,21 @@ def _test_stopping(llm: LLM,
|
|||||||
|
|
||||||
|
|
||||||
def _stop_basic(llm):
|
def _stop_basic(llm):
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
|
llm,
|
||||||
stop=["."],
|
stop=["."],
|
||||||
include_in_output=False,
|
include_in_output=False,
|
||||||
expected_output="VLLM is a 100% volunteer organization",
|
expected_output="VLLM is a 100% volunteer organization",
|
||||||
expected_reason=".")
|
expected_reason=".",
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
|
llm,
|
||||||
stop=["."],
|
stop=["."],
|
||||||
include_in_output=True,
|
include_in_output=True,
|
||||||
expected_output="VLLM is a 100% volunteer organization.",
|
expected_output="VLLM is a 100% volunteer organization.",
|
||||||
expected_reason=".")
|
expected_reason=".",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _stop_multi_tokens(llm):
|
def _stop_multi_tokens(llm):
|
||||||
@@ -52,45 +59,54 @@ def _stop_multi_tokens(llm):
|
|||||||
stop=["group of peo", "short"],
|
stop=["group of peo", "short"],
|
||||||
include_in_output=False,
|
include_in_output=False,
|
||||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||||
expected_reason="group of peo")
|
expected_reason="group of peo",
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(
|
_test_stopping(
|
||||||
llm,
|
llm,
|
||||||
stop=["group of peo", "short"],
|
stop=["group of peo", "short"],
|
||||||
include_in_output=True,
|
include_in_output=True,
|
||||||
expected_output=
|
expected_output="VLLM is a 100% volunteer organization. We are a group of peo",
|
||||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
expected_reason="group of peo",
|
||||||
expected_reason="group of peo")
|
)
|
||||||
|
|
||||||
|
|
||||||
def _stop_partial_token(llm):
|
def _stop_partial_token(llm):
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
|
llm,
|
||||||
stop=["gani"],
|
stop=["gani"],
|
||||||
include_in_output=False,
|
include_in_output=False,
|
||||||
expected_output="VLLM is a 100% volunteer or",
|
expected_output="VLLM is a 100% volunteer or",
|
||||||
expected_reason="gani")
|
expected_reason="gani",
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
|
llm,
|
||||||
stop=["gani"],
|
stop=["gani"],
|
||||||
include_in_output=True,
|
include_in_output=True,
|
||||||
expected_output="VLLM is a 100% volunteer organi",
|
expected_output="VLLM is a 100% volunteer organi",
|
||||||
expected_reason="gani")
|
expected_reason="gani",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _stop_token_id(llm):
|
def _stop_token_id(llm):
|
||||||
# token id 13013 => " organization"
|
# token id 13013 => " organization"
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
|
llm,
|
||||||
stop_token_ids=[13013],
|
stop_token_ids=[13013],
|
||||||
include_in_output=False,
|
include_in_output=False,
|
||||||
expected_output="VLLM is a 100% volunteer",
|
expected_output="VLLM is a 100% volunteer",
|
||||||
expected_reason=13013)
|
expected_reason=13013,
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
|
llm,
|
||||||
stop_token_ids=[13013],
|
stop_token_ids=[13013],
|
||||||
include_in_output=True,
|
include_in_output=True,
|
||||||
expected_output="VLLM is a 100% volunteer organization",
|
expected_output="VLLM is a 100% volunteer organization",
|
||||||
expected_reason=13013)
|
expected_reason=13013,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
|
|||||||
@@ -111,8 +111,7 @@ class MockSubscriber:
|
|||||||
self.last_seq = -1
|
self.last_seq = -1
|
||||||
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
||||||
|
|
||||||
def receive_one(self,
|
def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]:
|
||||||
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
|
|
||||||
"""Receive a single message with timeout"""
|
"""Receive a single message with timeout"""
|
||||||
if not self.sub.poll(timeout):
|
if not self.sub.poll(timeout):
|
||||||
return None
|
return None
|
||||||
@@ -135,8 +134,7 @@ class MockSubscriber:
|
|||||||
|
|
||||||
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
||||||
|
|
||||||
def receive_replay(self,
|
def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
||||||
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
|
||||||
"""Receive replayed messages from a specific replay socket"""
|
"""Receive replayed messages from a specific replay socket"""
|
||||||
if not self.replay_sockets:
|
if not self.replay_sockets:
|
||||||
raise ValueError("Replay sockets not initialized")
|
raise ValueError("Replay sockets not initialized")
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||||
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
|
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
|
||||||
CustomAllreduce)
|
CustomAllreduce,
|
||||||
|
)
|
||||||
|
|
||||||
# create a cpu process group for communicating metadata (ipc handle)
|
# create a cpu process group for communicating metadata (ipc handle)
|
||||||
dist.init_process_group(backend="gloo")
|
dist.init_process_group(backend="gloo")
|
||||||
@@ -52,7 +53,8 @@ for p in pointers:
|
|||||||
assert ord(host_data[i]) == byte_value, (
|
assert ord(host_data[i]) == byte_value, (
|
||||||
f"Rank {rank} failed"
|
f"Rank {rank} failed"
|
||||||
f" to verify buffer {p}. Expected {byte_value}, "
|
f" to verify buffer {p}. Expected {byte_value}, "
|
||||||
f"got {ord(host_data[i])}")
|
f"got {ord(host_data[i])}"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Rank {rank} verified all buffers")
|
print(f"Rank {rank} verified all buffers")
|
||||||
|
|
||||||
|
|||||||
@@ -13,13 +13,19 @@ import pytest
|
|||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
|
from vllm.distributed import (
|
||||||
|
broadcast_tensor_dict,
|
||||||
|
get_pp_group,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
tensor_model_parallel_reduce_scatter)
|
tensor_model_parallel_reduce_scatter,
|
||||||
|
)
|
||||||
|
|
||||||
from ..utils import (init_test_distributed_environment, multi_gpu_test,
|
from ..utils import (
|
||||||
multi_process_parallel)
|
init_test_distributed_environment,
|
||||||
|
multi_gpu_test,
|
||||||
|
multi_process_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@ray.remote(num_gpus=1, max_calls=1)
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
@@ -37,12 +43,11 @@ def all_reduce_test_worker(
|
|||||||
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
num_elements = 8
|
num_elements = 8
|
||||||
all_tensors = [
|
all_tensors = [
|
||||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
|
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
|
||||||
(r + 1) for r in range(tp_size)
|
for r in range(tp_size)
|
||||||
]
|
]
|
||||||
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||||
t = all_tensors[rank % tp_size]
|
t = all_tensors[rank % tp_size]
|
||||||
@@ -51,28 +56,31 @@ def all_reduce_test_worker(
|
|||||||
|
|
||||||
|
|
||||||
@ray.remote(num_gpus=1, max_calls=1)
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
|
def reduce_scatter_test_worker(
|
||||||
pp_size: int, rank: int,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
distributed_init_port: str):
|
tp_size: int,
|
||||||
|
pp_size: int,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_port: str,
|
||||||
|
):
|
||||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||||
# so that each worker can see all the GPUs
|
# so that each worker can see all the GPUs
|
||||||
# they will be able to set the device to the correct GPU
|
# they will be able to set the device to the correct GPU
|
||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
num_elements = 8
|
num_elements = 8
|
||||||
all_tensors = [
|
all_tensors = [
|
||||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
|
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
|
||||||
(r + 1) for r in range(tp_size)
|
for r in range(tp_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
index = rank % tp_size
|
index = rank % tp_size
|
||||||
partition_size = num_elements // tp_size
|
partition_size = num_elements // tp_size
|
||||||
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||||
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
|
expected = all_reduce[index * partition_size : (index + 1) * partition_size]
|
||||||
t = all_tensors[index]
|
t = all_tensors[index]
|
||||||
t = tensor_model_parallel_reduce_scatter(t, 0)
|
t = tensor_model_parallel_reduce_scatter(t, 0)
|
||||||
torch.testing.assert_close(t, expected)
|
torch.testing.assert_close(t, expected)
|
||||||
@@ -92,8 +100,7 @@ def all_gather_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
num_dimensions = 3
|
num_dimensions = 3
|
||||||
tensor_size = list(range(2, num_dimensions + 2))
|
tensor_size = list(range(2, num_dimensions + 2))
|
||||||
total_size = 1
|
total_size = 1
|
||||||
@@ -101,8 +108,10 @@ def all_gather_test_worker(
|
|||||||
total_size *= s
|
total_size *= s
|
||||||
for all_gather_dimension in range(num_dimensions):
|
for all_gather_dimension in range(num_dimensions):
|
||||||
all_tensors = [
|
all_tensors = [
|
||||||
torch.arange(total_size, dtype=torch.float32,
|
torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(
|
||||||
device="cuda").reshape(tensor_size) * (r + 1)
|
tensor_size
|
||||||
|
)
|
||||||
|
* (r + 1)
|
||||||
for r in range(tp_size)
|
for r in range(tp_size)
|
||||||
]
|
]
|
||||||
expected = torch.cat(all_tensors, dim=all_gather_dimension)
|
expected = torch.cat(all_tensors, dim=all_gather_dimension)
|
||||||
@@ -125,8 +134,7 @@ def broadcast_tensor_dict_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
test_dict = {
|
test_dict = {
|
||||||
# device tensor
|
# device tensor
|
||||||
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||||
@@ -134,10 +142,7 @@ def broadcast_tensor_dict_test_worker(
|
|||||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||||
"c": "test",
|
"c": "test",
|
||||||
"d": [1, 2, 3],
|
"d": [1, 2, 3],
|
||||||
"e": {
|
"e": {"a": 1, "b": 2},
|
||||||
"a": 1,
|
|
||||||
"b": 2
|
|
||||||
},
|
|
||||||
# empty tensor
|
# empty tensor
|
||||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||||
}
|
}
|
||||||
@@ -166,8 +171,7 @@ def send_recv_tensor_dict_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
test_dict = {
|
test_dict = {
|
||||||
# device tensor
|
# device tensor
|
||||||
@@ -176,10 +180,7 @@ def send_recv_tensor_dict_test_worker(
|
|||||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||||
"c": "test",
|
"c": "test",
|
||||||
"d": [1, 2, 3],
|
"d": [1, 2, 3],
|
||||||
"e": {
|
"e": {"a": 1, "b": 2},
|
||||||
"a": 1,
|
|
||||||
"b": 2
|
|
||||||
},
|
|
||||||
# empty tensor
|
# empty tensor
|
||||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||||
}
|
}
|
||||||
@@ -211,8 +212,7 @@ def send_recv_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
size = 64
|
size = 64
|
||||||
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
|
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
|
||||||
@@ -229,10 +229,10 @@ def send_recv_test_worker(
|
|||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("test_target", [
|
@pytest.mark.parametrize(
|
||||||
all_reduce_test_worker, all_gather_test_worker,
|
"test_target",
|
||||||
broadcast_tensor_dict_test_worker
|
[all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker],
|
||||||
])
|
)
|
||||||
def test_multi_process_tensor_parallel(
|
def test_multi_process_tensor_parallel(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
@@ -244,7 +244,8 @@ def test_multi_process_tensor_parallel(
|
|||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize("pp_size", [2])
|
@pytest.mark.parametrize("pp_size", [2])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
|
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]
|
||||||
|
)
|
||||||
def test_multi_process_pipeline_parallel(
|
def test_multi_process_pipeline_parallel(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
pp_size: int,
|
pp_size: int,
|
||||||
@@ -256,11 +257,16 @@ def test_multi_process_pipeline_parallel(
|
|||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("pp_size", [2])
|
@pytest.mark.parametrize("pp_size", [2])
|
||||||
@pytest.mark.parametrize("test_target", [
|
@pytest.mark.parametrize(
|
||||||
send_recv_test_worker, send_recv_tensor_dict_test_worker,
|
"test_target",
|
||||||
all_reduce_test_worker, all_gather_test_worker,
|
[
|
||||||
broadcast_tensor_dict_test_worker
|
send_recv_test_worker,
|
||||||
])
|
send_recv_tensor_dict_test_worker,
|
||||||
|
all_reduce_test_worker,
|
||||||
|
all_gather_test_worker,
|
||||||
|
broadcast_tensor_dict_test_worker,
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_multi_process_tensor_parallel_pipeline_parallel(
|
def test_multi_process_tensor_parallel_pipeline_parallel(
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
pp_size: int,
|
pp_size: int,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
|||||||
all workers in a node other than the head node, which can cause the test
|
all workers in a node other than the head node, which can cause the test
|
||||||
to fail.
|
to fail.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -56,7 +57,8 @@ class CPTestSettings:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Length mismatch: distributed_backends "
|
f"Length mismatch: distributed_backends "
|
||||||
f"({len(self.distributed_backends)}) != "
|
f"({len(self.distributed_backends)}) != "
|
||||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
f"vllm_major_versions ({len(self.vllm_major_versions)})"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def detailed(
|
def detailed(
|
||||||
@@ -74,29 +76,39 @@ class CPTestSettings:
|
|||||||
for dcp_multiplier in [0.5, 1]:
|
for dcp_multiplier in [0.5, 1]:
|
||||||
for chunked_prefill_val in [True]:
|
for chunked_prefill_val in [True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
|
tp_size=tp_base,
|
||||||
pp_size=pp_multiplier * pp_base,
|
pp_size=pp_multiplier * pp_base,
|
||||||
dcp_size=int(dcp_multiplier *
|
dcp_size=int(dcp_multiplier * tp_base),
|
||||||
tp_base),
|
|
||||||
eager_mode=eager_mode_val,
|
eager_mode=eager_mode_val,
|
||||||
chunked_prefill=chunked_prefill_val))
|
chunked_prefill=chunked_prefill_val,
|
||||||
|
)
|
||||||
|
)
|
||||||
return CPTestSettings(
|
return CPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp"],
|
distributed_backends=["mp"],
|
||||||
vllm_major_versions=["1"],
|
vllm_major_versions=["1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=CPTestOptions(multi_node_only=multi_node_only,
|
test_options=CPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_id: str):
|
def iter_params(self, model_id: str):
|
||||||
opts = self.test_options
|
opts = self.test_options
|
||||||
|
|
||||||
for parallel_setup in self.parallel_setups:
|
for parallel_setup in self.parallel_setups:
|
||||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
for backend, vllm_major_version in zip(
|
||||||
self.vllm_major_versions):
|
self.distributed_backends, self.vllm_major_versions
|
||||||
yield (model_id, parallel_setup, backend, vllm_major_version,
|
):
|
||||||
self.runner, opts)
|
yield (
|
||||||
|
model_id,
|
||||||
|
parallel_setup,
|
||||||
|
backend,
|
||||||
|
vllm_major_version,
|
||||||
|
self.runner,
|
||||||
|
opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _compare_cp_with_tp(
|
def _compare_cp_with_tp(
|
||||||
@@ -148,8 +160,10 @@ def _compare_cp_with_tp(
|
|||||||
if num_gpus_available < tp_size * pp_size:
|
if num_gpus_available < tp_size * pp_size:
|
||||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
pytest.skip(
|
||||||
"multiprocessing distributed backend")
|
"Skipping multi-node pipeline parallel test for "
|
||||||
|
"multiprocessing distributed backend"
|
||||||
|
)
|
||||||
if multi_node_only and not VLLM_MULTI_NODE:
|
if multi_node_only and not VLLM_MULTI_NODE:
|
||||||
pytest.skip("Not in multi-node setting")
|
pytest.skip("Not in multi-node setting")
|
||||||
|
|
||||||
@@ -178,8 +192,7 @@ def _compare_cp_with_tp(
|
|||||||
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||||
|
|
||||||
cp_env = tp_env = {
|
cp_env = tp_env = {
|
||||||
"VLLM_USE_V1":
|
"VLLM_USE_V1": vllm_major_version, # Note(hc): DCP only support V1 engine only
|
||||||
vllm_major_version, # Note(hc): DCP only support V1 engine only
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cp_args = [
|
cp_args = [
|
||||||
@@ -205,13 +218,15 @@ def _compare_cp_with_tp(
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(
|
||||||
|
model_id,
|
||||||
cp_args,
|
cp_args,
|
||||||
tp_args,
|
tp_args,
|
||||||
cp_env,
|
cp_env,
|
||||||
tp_env,
|
tp_env,
|
||||||
method=method,
|
method=method,
|
||||||
max_wait_seconds=720)
|
max_wait_seconds=720,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
testing_ray_compiled_graph = cp_env is not None
|
testing_ray_compiled_graph = cp_env is not None
|
||||||
if testing_ray_compiled_graph and vllm_major_version == "0":
|
if testing_ray_compiled_graph and vllm_major_version == "0":
|
||||||
@@ -224,9 +239,10 @@ def _compare_cp_with_tp(
|
|||||||
|
|
||||||
CP_TEXT_GENERATION_MODELS = {
|
CP_TEXT_GENERATION_MODELS = {
|
||||||
# [MLA attention only]
|
# [MLA attention only]
|
||||||
"deepseek-ai/DeepSeek-V2-Lite-Chat":
|
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||||
[CPTestSettings.detailed(),
|
CPTestSettings.detailed(),
|
||||||
CPTestSettings.detailed(tp_base=2)],
|
CPTestSettings.detailed(tp_base=2),
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
CP_TEST_MODELS = [
|
CP_TEST_MODELS = [
|
||||||
@@ -237,11 +253,19 @@ CP_TEST_MODELS = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
|
(
|
||||||
"runner", "test_options"),
|
"model_id",
|
||||||
|
"parallel_setup",
|
||||||
|
"distributed_backend",
|
||||||
|
"vllm_major_version",
|
||||||
|
"runner",
|
||||||
|
"test_options",
|
||||||
|
),
|
||||||
[
|
[
|
||||||
params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
|
params
|
||||||
for setting in settings for params in setting.iter_params(model_id)
|
for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
|
||||||
|
for setting in settings
|
||||||
|
for params in setting.iter_params(model_id)
|
||||||
if model_id in CP_TEST_MODELS
|
if model_id in CP_TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -255,7 +279,8 @@ def test_cp_generation(
|
|||||||
test_options: CPTestOptions,
|
test_options: CPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_cp_with_tp(model_id,
|
_compare_cp_with_tp(
|
||||||
|
model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
vllm_major_version,
|
vllm_major_version,
|
||||||
@@ -263,4 +288,5 @@ def test_cp_generation(
|
|||||||
test_options,
|
test_options,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
method="generate",
|
method="generate",
|
||||||
is_multimodal=False)
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|||||||
@@ -8,12 +8,14 @@ import ray
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||||
tensor_model_parallel_all_reduce)
|
|
||||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||||
|
|
||||||
from ..utils import (ensure_model_parallel_initialized,
|
from ..utils import (
|
||||||
init_test_distributed_environment, multi_process_parallel)
|
ensure_model_parallel_initialized,
|
||||||
|
init_test_distributed_environment,
|
||||||
|
multi_process_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
||||||
@@ -33,8 +35,7 @@ def graph_allreduce(
|
|||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||||
group = get_tp_group().device_group
|
group = get_tp_group().device_group
|
||||||
|
|
||||||
@@ -60,18 +61,15 @@ def graph_allreduce(
|
|||||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
with graph_capture(device=device) as graph_capture_context:
|
with graph_capture(device=device) as graph_capture_context:
|
||||||
# use integers so result matches NCCL exactly
|
# use integers so result matches NCCL exactly
|
||||||
inp1 = torch.randint(1,
|
inp1 = torch.randint(
|
||||||
16, (sz, ),
|
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
dtype=dtype,
|
)
|
||||||
device=torch.cuda.current_device())
|
inp2 = torch.randint(
|
||||||
inp2 = torch.randint(1,
|
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
16, (sz, ),
|
)
|
||||||
dtype=dtype,
|
|
||||||
device=torch.cuda.current_device())
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(graph,
|
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||||
stream=graph_capture_context.stream):
|
|
||||||
for i in range(num_communication):
|
for i in range(num_communication):
|
||||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||||
# the input buffer is immediately modified to test
|
# the input buffer is immediately modified to test
|
||||||
@@ -96,8 +94,7 @@ def eager_allreduce(
|
|||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
# we use the first group to communicate once
|
# we use the first group to communicate once
|
||||||
# and the second group to communicate twice
|
# and the second group to communicate twice
|
||||||
@@ -132,5 +129,4 @@ def test_custom_allreduce(
|
|||||||
world_size = tp_size * pipeline_parallel_size
|
world_size = tp_size * pipeline_parallel_size
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
|
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||||
test_target)
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from ..entrypoints.openai.test_oot_registration import (
|
from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server
|
||||||
run_and_test_dummy_opt_api_server)
|
|
||||||
|
|
||||||
|
|
||||||
def test_distributed_oot(dummy_opt_path: str):
|
def test_distributed_oot(dummy_opt_path: str):
|
||||||
|
|||||||
@@ -10,10 +10,12 @@ from vllm.distributed.eplb.rebalance_algo import rebalance_experts
|
|||||||
def test_basic_rebalance():
|
def test_basic_rebalance():
|
||||||
"""Test basic rebalancing functionality"""
|
"""Test basic rebalancing functionality"""
|
||||||
# Example from https://github.com/deepseek-ai/eplb
|
# Example from https://github.com/deepseek-ai/eplb
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
|
[
|
||||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
num_layers = weight.shape[0]
|
num_layers = weight.shape[0]
|
||||||
num_replicas = 16
|
num_replicas = 16
|
||||||
@@ -21,45 +23,49 @@ def test_basic_rebalance():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify output shapes
|
# Verify output shapes
|
||||||
assert phy2log.shape == (
|
assert phy2log.shape == (
|
||||||
2,
|
2,
|
||||||
16,
|
16,
|
||||||
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
|
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
|
||||||
assert (log2phy.shape[0] == 2
|
assert log2phy.shape[0] == 2, (
|
||||||
), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
|
f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
|
||||||
assert (
|
)
|
||||||
log2phy.shape[1] == 12
|
assert log2phy.shape[1] == 12, (
|
||||||
), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
|
f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
|
||||||
|
)
|
||||||
assert logcnt.shape == (
|
assert logcnt.shape == (
|
||||||
2,
|
2,
|
||||||
12,
|
12,
|
||||||
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
|
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
|
||||||
|
|
||||||
# Verify physical to logical expert mapping range is correct
|
# Verify physical to logical expert mapping range is correct
|
||||||
assert torch.all(phy2log >= 0) and torch.all(
|
assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), (
|
||||||
phy2log < 12), "Physical to logical mapping should be in range [0, 12)"
|
"Physical to logical mapping should be in range [0, 12)"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify expert count reasonableness
|
# Verify expert count reasonableness
|
||||||
assert torch.all(
|
assert torch.all(logcnt >= 1), "Each logical expert should have at least 1 replica"
|
||||||
logcnt >= 1), "Each logical expert should have at least 1 replica"
|
assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, (
|
||||||
assert (
|
f"Total replicas should be {num_replicas * num_layers}"
|
||||||
torch.sum(logcnt, dim=1).sum() == num_replicas *
|
)
|
||||||
num_layers), f"Total replicas should be {num_replicas * num_layers}"
|
|
||||||
|
|
||||||
# Verify expected output
|
# Verify expected output
|
||||||
expected_phy2log = torch.tensor([
|
expected_phy2log = torch.tensor(
|
||||||
|
[
|
||||||
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
||||||
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
|
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
|
||||||
])
|
]
|
||||||
|
)
|
||||||
assert torch.all(phy2log == expected_phy2log)
|
assert torch.all(phy2log == expected_phy2log)
|
||||||
|
|
||||||
expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
|
expected_logcnt = torch.tensor(
|
||||||
[1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]])
|
[[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]
|
||||||
|
)
|
||||||
assert torch.all(logcnt == expected_logcnt)
|
assert torch.all(logcnt == expected_logcnt)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,9 +77,9 @@ def test_single_gpu_case():
|
|||||||
num_nodes = 1
|
num_nodes = 1
|
||||||
num_gpus = 1
|
num_gpus = 1
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 4)
|
assert phy2log.shape == (1, 4)
|
||||||
@@ -93,19 +99,19 @@ def test_equal_weights():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 8)
|
assert phy2log.shape == (1, 8)
|
||||||
assert logcnt.shape == (1, 8)
|
assert logcnt.shape == (1, 8)
|
||||||
|
|
||||||
# With equal weights, each expert should have exactly one replica
|
# With equal weights, each expert should have exactly one replica
|
||||||
assert torch.all(
|
assert torch.all(logcnt == 1), (
|
||||||
logcnt == 1
|
"With equal weights and no replication, "
|
||||||
), "With equal weights and no replication, " \
|
|
||||||
"each expert should have exactly 1 replica"
|
"each expert should have exactly 1 replica"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extreme_weight_imbalance():
|
def test_extreme_weight_imbalance():
|
||||||
@@ -116,35 +122,37 @@ def test_extreme_weight_imbalance():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 12)
|
assert phy2log.shape == (1, 12)
|
||||||
assert logcnt.shape == (1, 8)
|
assert logcnt.shape == (1, 8)
|
||||||
|
|
||||||
# Expert with highest weight (index 0) should have more replicas
|
# Expert with highest weight (index 0) should have more replicas
|
||||||
assert (
|
assert logcnt[0, 0] > logcnt[0, 1], (
|
||||||
logcnt[0, 0]
|
"Expert with highest weight should have more replicas"
|
||||||
> logcnt[0, 1]), "Expert with highest weight should have more replicas"
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_layers():
|
def test_multiple_layers():
|
||||||
"""Test multiple layers case"""
|
"""Test multiple layers case"""
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
|
[
|
||||||
[10, 20, 30, 40, 50, 60], # First layer
|
[10, 20, 30, 40, 50, 60], # First layer
|
||||||
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
|
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
|
||||||
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
|
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
|
||||||
])
|
]
|
||||||
|
)
|
||||||
num_replicas = 8
|
num_replicas = 8
|
||||||
num_groups = 2
|
num_groups = 2
|
||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (3, 8)
|
assert phy2log.shape == (3, 8)
|
||||||
@@ -152,12 +160,12 @@ def test_multiple_layers():
|
|||||||
|
|
||||||
# Verify expert allocation is reasonable for each layer
|
# Verify expert allocation is reasonable for each layer
|
||||||
for layer in range(3):
|
for layer in range(3):
|
||||||
assert torch.all(phy2log[layer] >= 0) and torch.all(
|
assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), (
|
||||||
phy2log[layer] < 6
|
f"Layer {layer} physical to logical mappingshould be in range [0, 6)"
|
||||||
), f"Layer {layer} physical to logical mapping" \
|
)
|
||||||
"should be in range [0, 6)"
|
assert torch.sum(logcnt[layer]) == num_replicas, (
|
||||||
assert (torch.sum(logcnt[layer]) == num_replicas
|
f"Layer {layer} total replicas should be {num_replicas}"
|
||||||
), f"Layer {layer} total replicas should be {num_replicas}"
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_parameter_validation():
|
def test_parameter_validation():
|
||||||
@@ -179,17 +187,19 @@ def test_parameter_validation():
|
|||||||
|
|
||||||
def test_small_scale_hierarchical():
|
def test_small_scale_hierarchical():
|
||||||
"""Test small-scale hierarchical load balancing"""
|
"""Test small-scale hierarchical load balancing"""
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
|
[
|
||||||
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
|
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
|
||||||
])
|
]
|
||||||
|
)
|
||||||
num_replicas = 12
|
num_replicas = 12
|
||||||
num_groups = 4 # 4 groups, 2 experts each
|
num_groups = 4 # 4 groups, 2 experts each
|
||||||
num_nodes = 2 # 2 nodes
|
num_nodes = 2 # 2 nodes
|
||||||
num_gpus = 4 # 4 GPUs
|
num_gpus = 4 # 4 GPUs
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify basic constraints
|
# Verify basic constraints
|
||||||
assert phy2log.shape == (1, 12)
|
assert phy2log.shape == (1, 12)
|
||||||
@@ -199,8 +209,9 @@ def test_small_scale_hierarchical():
|
|||||||
|
|
||||||
# Expert with highest weight should have more replicas
|
# Expert with highest weight should have more replicas
|
||||||
max_weight_expert = torch.argmax(weight[0])
|
max_weight_expert = torch.argmax(weight[0])
|
||||||
assert (logcnt[0, max_weight_expert]
|
assert logcnt[0, max_weight_expert] >= 2, (
|
||||||
>= 2), "Highest weight expert should have multiple replicas"
|
"Highest weight expert should have multiple replicas"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_global_load_balance_fallback():
|
def test_global_load_balance_fallback():
|
||||||
@@ -213,9 +224,9 @@ def test_global_load_balance_fallback():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Should work normally, just using global load balancing strategy
|
# Should work normally, just using global load balancing strategy
|
||||||
assert phy2log.shape == (1, 8)
|
assert phy2log.shape == (1, 8)
|
||||||
@@ -235,9 +246,9 @@ def test_device_compatibility(device):
|
|||||||
num_nodes = 1
|
num_nodes = 1
|
||||||
num_gpus = 2
|
num_gpus = 2
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Function will convert to CPU internally, but should handle different
|
# Function will convert to CPU internally, but should handle different
|
||||||
# device inputs normally
|
# device inputs normally
|
||||||
@@ -250,7 +261,8 @@ def test_additional_cases():
|
|||||||
|
|
||||||
# Test case 1: Large-scale distributed setup
|
# Test case 1: Large-scale distributed setup
|
||||||
weight1 = torch.tensor(
|
weight1 = torch.tensor(
|
||||||
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]])
|
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
|
||||||
|
)
|
||||||
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
|
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
|
||||||
|
|
||||||
assert phy2log1.shape == (1, 24)
|
assert phy2log1.shape == (1, 24)
|
||||||
@@ -258,10 +270,12 @@ def test_additional_cases():
|
|||||||
assert torch.sum(logcnt1) == 24
|
assert torch.sum(logcnt1) == 24
|
||||||
|
|
||||||
# Test case 2: Different weight distributions
|
# Test case 2: Different weight distributions
|
||||||
weight2 = torch.tensor([
|
weight2 = torch.tensor(
|
||||||
|
[
|
||||||
[200, 150, 100, 50, 25, 12], # Decreasing weights
|
[200, 150, 100, 50, 25, 12], # Decreasing weights
|
||||||
[12, 25, 50, 100, 150, 200], # Increasing weights
|
[12, 25, 50, 100, 150, 200], # Increasing weights
|
||||||
])
|
]
|
||||||
|
)
|
||||||
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
|
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
|
||||||
|
|
||||||
assert phy2log2.shape == (2, 10)
|
assert phy2log2.shape == (2, 10)
|
||||||
@@ -274,19 +288,21 @@ def test_additional_cases():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
|
[
|
||||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
num_replicas = 16
|
num_replicas = 16
|
||||||
num_groups = 4
|
num_groups = 4
|
||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
print(phy2log)
|
print(phy2log)
|
||||||
|
|
||||||
test_basic_rebalance()
|
test_basic_rebalance()
|
||||||
|
|||||||
@@ -9,11 +9,12 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from vllm.distributed.eplb.rebalance_execute import (
|
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||||
rearrange_expert_weights_inplace)
|
from vllm.distributed.parallel_state import (
|
||||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
ensure_model_parallel_initialized,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
init_distributed_environment)
|
init_distributed_environment,
|
||||||
|
)
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
|
|
||||||
@@ -22,13 +23,13 @@ def distributed_run(fn, world_size):
|
|||||||
processes: list[multiprocessing.Process] = []
|
processes: list[multiprocessing.Process] = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env: dict[str, str] = {}
|
env: dict[str, str] = {}
|
||||||
env['RANK'] = str(i)
|
env["RANK"] = str(i)
|
||||||
env['LOCAL_RANK'] = str(i)
|
env["LOCAL_RANK"] = str(i)
|
||||||
env['WORLD_SIZE'] = str(number_of_processes)
|
env["WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['MASTER_ADDR'] = 'localhost'
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env['MASTER_PORT'] = '12345'
|
env["MASTER_PORT"] = "12345"
|
||||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
p = multiprocessing.Process(target=fn, args=(env,))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
@@ -45,7 +46,7 @@ def worker_fn_wrapper(fn):
|
|||||||
# and update the environment variables in the function
|
# and update the environment variables in the function
|
||||||
def wrapped_fn(env):
|
def wrapped_fn(env):
|
||||||
update_environment_variables(env)
|
update_environment_variables(env)
|
||||||
local_rank = os.environ['LOCAL_RANK']
|
local_rank = os.environ["LOCAL_RANK"]
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -120,27 +121,27 @@ def create_expert_weights(
|
|||||||
for layer in range(num_layers):
|
for layer in range(num_layers):
|
||||||
layer_weights = []
|
layer_weights = []
|
||||||
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
||||||
weight_tensor = torch.zeros(num_local_experts,
|
weight_tensor = torch.zeros(
|
||||||
hidden_size,
|
num_local_experts, hidden_size, device=device, dtype=torch.float32
|
||||||
device=device,
|
)
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
for local_expert in range(num_local_experts):
|
for local_expert in range(num_local_experts):
|
||||||
# Get the logical expert ID for this physical expert
|
# Get the logical expert ID for this physical expert
|
||||||
global_pos = rank * num_local_experts + local_expert
|
global_pos = rank * num_local_experts + local_expert
|
||||||
logical_expert_id = physical_to_logical_mapping[
|
logical_expert_id = physical_to_logical_mapping[
|
||||||
layer, global_pos].item()
|
layer, global_pos
|
||||||
|
].item()
|
||||||
|
|
||||||
# Generate weights based on logical expert ID
|
# Generate weights based on logical expert ID
|
||||||
# (so that all replicas of the same logical expert have the
|
# (so that all replicas of the same logical expert have the
|
||||||
# same weights)
|
# same weights)
|
||||||
base_value = (logical_expert_id * 1000 + layer * 100 +
|
base_value = logical_expert_id * 1000 + layer * 100 + weight_idx * 10
|
||||||
weight_idx * 10)
|
weight_tensor[local_expert] = torch.arange(
|
||||||
weight_tensor[local_expert] = torch.arange(base_value,
|
base_value,
|
||||||
base_value +
|
base_value + hidden_size,
|
||||||
hidden_size,
|
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.float32)
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
layer_weights.append(weight_tensor)
|
layer_weights.append(weight_tensor)
|
||||||
expert_weights.append(layer_weights)
|
expert_weights.append(layer_weights)
|
||||||
@@ -182,12 +183,15 @@ def verify_expert_weights_after_shuffle(
|
|||||||
|
|
||||||
# Check if the weights are correct
|
# Check if the weights are correct
|
||||||
actual_weights = weight_tensor[local_expert]
|
actual_weights = weight_tensor[local_expert]
|
||||||
expected_base = (expected_logical_expert * 1000 + layer * 100 +
|
expected_base = (
|
||||||
weight_idx * 10)
|
expected_logical_expert * 1000 + layer * 100 + weight_idx * 10
|
||||||
expected_weights = torch.arange(expected_base,
|
)
|
||||||
|
expected_weights = torch.arange(
|
||||||
|
expected_base,
|
||||||
expected_base + hidden_size,
|
expected_base + hidden_size,
|
||||||
device=actual_weights.device,
|
device=actual_weights.device,
|
||||||
dtype=actual_weights.dtype)
|
dtype=actual_weights.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
actual_weights,
|
actual_weights,
|
||||||
@@ -195,7 +199,8 @@ def verify_expert_weights_after_shuffle(
|
|||||||
msg=f"Layer {layer}, weight {weight_idx},"
|
msg=f"Layer {layer}, weight {weight_idx},"
|
||||||
f"local expert {local_expert}: "
|
f"local expert {local_expert}: "
|
||||||
f"weights do not match. "
|
f"weights do not match. "
|
||||||
f"Expected logical expert {expected_logical_expert}")
|
f"Expected logical expert {expected_logical_expert}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_redundant_experts_have_same_weights(
|
def verify_redundant_experts_have_same_weights(
|
||||||
@@ -222,23 +227,23 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
total_physical_experts,
|
total_physical_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
device=expert_weights[layer][weight_idx].device,
|
device=expert_weights[layer][weight_idx].device,
|
||||||
dtype=expert_weights[layer][weight_idx].dtype)
|
dtype=expert_weights[layer][weight_idx].dtype,
|
||||||
|
)
|
||||||
|
|
||||||
# Use all_gather to collect expert weights from current node
|
# Use all_gather to collect expert weights from current node
|
||||||
# expert_weights[layer][weight_idx] shape:
|
# expert_weights[layer][weight_idx] shape:
|
||||||
# [num_local_experts, hidden_size]
|
# [num_local_experts, hidden_size]
|
||||||
local_weights = expert_weights[layer][
|
local_weights = expert_weights[layer][
|
||||||
weight_idx] # [num_local_experts, hidden_size]
|
weight_idx
|
||||||
|
] # [num_local_experts, hidden_size]
|
||||||
|
|
||||||
# Split tensor along dim 0 into a list for all_gather
|
# Split tensor along dim 0 into a list for all_gather
|
||||||
gathered_weights_list = torch.chunk(gathered_weights,
|
gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0)
|
||||||
world_size,
|
|
||||||
dim=0)
|
|
||||||
|
|
||||||
torch.distributed.all_gather(
|
torch.distributed.all_gather(
|
||||||
# Output list: each element corresponds to one rank's weights
|
# Output list: each element corresponds to one rank's weights
|
||||||
list(gathered_weights_list),
|
list(gathered_weights_list),
|
||||||
local_weights # Input: current rank's local weights
|
local_weights, # Input: current rank's local weights
|
||||||
)
|
)
|
||||||
|
|
||||||
all_weights.append(gathered_weights)
|
all_weights.append(gathered_weights)
|
||||||
@@ -266,7 +271,8 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
msg=f"Layer {layer}, weight {weight_idx},"
|
msg=f"Layer {layer}, weight {weight_idx},"
|
||||||
f"logical expert {logical_expert_id}: "
|
f"logical expert {logical_expert_id}: "
|
||||||
f"Physical expert {physical_pos} has different weights"
|
f"Physical expert {physical_pos} has different weights"
|
||||||
f"than expected")
|
f"than expected",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -290,10 +296,11 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
# 4 GPU, 8 experts per GPU
|
# 4 GPU, 8 experts per GPU
|
||||||
# 16 logical experts, 32 physical experts, 16 redundant experts
|
# 16 logical experts, 32 physical experts, 16 redundant experts
|
||||||
(4, 8, 8, 16),
|
(4, 8, 8, 16),
|
||||||
])
|
],
|
||||||
def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
)
|
||||||
num_local_experts,
|
def test_rearrange_expert_weights_with_redundancy(
|
||||||
num_logical_experts):
|
world_size, num_layers, num_local_experts, num_logical_experts
|
||||||
|
):
|
||||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||||
|
|
||||||
if torch.cuda.device_count() < world_size:
|
if torch.cuda.device_count() < world_size:
|
||||||
@@ -304,8 +311,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||||
# to expert parallel)
|
# to expert parallel)
|
||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size=world_size,
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
pipeline_model_parallel_size=1)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group = get_tp_group().cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
@@ -316,8 +323,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
hidden_sizes = [32, 64] # Two different weight matrices
|
hidden_sizes = [32, 64] # Two different weight matrices
|
||||||
|
|
||||||
# Create old expert indices (with redundancy)
|
# Create old expert indices (with redundancy)
|
||||||
redundancy_config = create_redundancy_config(num_logical_experts,
|
redundancy_config = create_redundancy_config(
|
||||||
total_physical_experts)
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
|
|
||||||
old_indices = create_expert_indices_with_redundancy(
|
old_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers,
|
num_layers,
|
||||||
@@ -328,7 +336,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
|
|
||||||
# Create new expert indices (with redundancy)
|
# Create new expert indices (with redundancy)
|
||||||
new_redundancy_config = create_redundancy_config(
|
new_redundancy_config = create_redundancy_config(
|
||||||
num_logical_experts, total_physical_experts)
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
new_indices = create_expert_indices_with_redundancy(
|
new_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers,
|
num_layers,
|
||||||
num_logical_experts,
|
num_logical_experts,
|
||||||
@@ -337,9 +346,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create expert weights
|
# Create expert weights
|
||||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
expert_weights = create_expert_weights(
|
||||||
hidden_sizes, ep_rank, device,
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
old_indices)
|
)
|
||||||
|
|
||||||
# Execute weight rearrangement
|
# Execute weight rearrangement
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
@@ -383,8 +392,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size=world_size,
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
pipeline_model_parallel_size=1)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group = get_tp_group().cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
@@ -401,12 +410,12 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
|
|
||||||
# Same indices - no change
|
# Same indices - no change
|
||||||
indices = create_expert_indices_with_redundancy(
|
indices = create_expert_indices_with_redundancy(
|
||||||
num_layers, num_logical_experts, total_physical_experts,
|
num_layers, num_logical_experts, total_physical_experts, redundancy_config
|
||||||
redundancy_config)
|
)
|
||||||
|
|
||||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
expert_weights = create_expert_weights(
|
||||||
hidden_sizes, ep_rank, device,
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
|
||||||
indices)
|
)
|
||||||
|
|
||||||
# Save original weights
|
# Save original weights
|
||||||
original_weights = []
|
original_weights = []
|
||||||
@@ -422,7 +431,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
indices, # Same indices
|
indices, # Same indices
|
||||||
expert_weights,
|
expert_weights,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=False)
|
is_profile=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify that the weights have not changed
|
# Verify that the weights have not changed
|
||||||
for layer in range(num_layers):
|
for layer in range(num_layers):
|
||||||
@@ -430,8 +440,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
expert_weights[layer][weight_idx],
|
expert_weights[layer][weight_idx],
|
||||||
original_weights[layer][weight_idx],
|
original_weights[layer][weight_idx],
|
||||||
msg=f"Layer {layer}, weight {weight_idx} should remain "
|
msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
|
||||||
f"unchanged")
|
)
|
||||||
|
|
||||||
distributed_run(worker_fn, world_size)
|
distributed_run(worker_fn, world_size)
|
||||||
|
|
||||||
@@ -446,8 +456,8 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size=world_size,
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
pipeline_model_parallel_size=1)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group = get_tp_group().cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
@@ -460,21 +470,23 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
hidden_sizes = [32]
|
hidden_sizes = [32]
|
||||||
|
|
||||||
# Create different index distributions
|
# Create different index distributions
|
||||||
old_redundancy = create_redundancy_config(num_logical_experts,
|
old_redundancy = create_redundancy_config(
|
||||||
total_physical_experts)
|
num_logical_experts, total_physical_experts
|
||||||
new_redundancy = create_redundancy_config(num_logical_experts,
|
)
|
||||||
total_physical_experts)
|
new_redundancy = create_redundancy_config(
|
||||||
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
|
|
||||||
old_indices = create_expert_indices_with_redundancy(
|
old_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers, num_logical_experts, total_physical_experts,
|
num_layers, num_logical_experts, total_physical_experts, old_redundancy
|
||||||
old_redundancy)
|
)
|
||||||
new_indices = create_expert_indices_with_redundancy(
|
new_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers, num_logical_experts, total_physical_experts,
|
num_layers, num_logical_experts, total_physical_experts, new_redundancy
|
||||||
new_redundancy)
|
)
|
||||||
|
|
||||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
expert_weights = create_expert_weights(
|
||||||
hidden_sizes, ep_rank, device,
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
old_indices)
|
)
|
||||||
|
|
||||||
# Save original weights
|
# Save original weights
|
||||||
original_weights = []
|
original_weights = []
|
||||||
@@ -490,7 +502,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
new_indices,
|
new_indices,
|
||||||
expert_weights,
|
expert_weights,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=True # Profile mode
|
is_profile=True, # Profile mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# In profile mode, the weights should remain unchanged
|
# In profile mode, the weights should remain unchanged
|
||||||
@@ -499,6 +511,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
expert_weights[layer][weight_idx],
|
expert_weights[layer][weight_idx],
|
||||||
original_weights[layer][weight_idx],
|
original_weights[layer][weight_idx],
|
||||||
msg="In profile mode, the weights should remain unchanged")
|
msg="In profile mode, the weights should remain unchanged",
|
||||||
|
)
|
||||||
|
|
||||||
distributed_run(worker_fn, world_size)
|
distributed_run(worker_fn, world_size)
|
||||||
|
|||||||
@@ -6,8 +6,11 @@ import time
|
|||||||
import msgspec
|
import msgspec
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
from vllm.distributed.kv_events import (
|
||||||
NullEventPublisher)
|
EventBatch,
|
||||||
|
EventPublisherFactory,
|
||||||
|
NullEventPublisher,
|
||||||
|
)
|
||||||
|
|
||||||
DP_RANK = 0
|
DP_RANK = 0
|
||||||
|
|
||||||
@@ -15,15 +18,17 @@ DP_RANK = 0
|
|||||||
class EventSample(
|
class EventSample(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
tag=True, # type: ignore
|
tag=True, # type: ignore
|
||||||
array_like=True # type: ignore
|
array_like=True, # type: ignore
|
||||||
):
|
):
|
||||||
"""Test event for publisher testing"""
|
"""Test event for publisher testing"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
|
|
||||||
class SampleBatch(EventBatch):
|
class SampleBatch(EventBatch):
|
||||||
"""Test event batch for publisher testing"""
|
"""Test event batch for publisher testing"""
|
||||||
|
|
||||||
events: list[EventSample]
|
events: list[EventSample]
|
||||||
|
|
||||||
|
|
||||||
@@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber):
|
|||||||
|
|
||||||
seq, received = result
|
seq, received = result
|
||||||
assert seq == 0, "Sequence number mismatch"
|
assert seq == 0, "Sequence number mismatch"
|
||||||
assert received.ts == pytest.approx(test_batch.ts,
|
assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch"
|
||||||
abs=0.1), ("Timestamp mismatch")
|
assert len(received.events) == len(test_batch.events), "Number of events mismatch"
|
||||||
assert len(received.events) == len(
|
|
||||||
test_batch.events), ("Number of events mismatch")
|
|
||||||
|
|
||||||
for i, event in enumerate(received.events):
|
for i, event in enumerate(received.events):
|
||||||
assert event.id == i, "Event id mismatch"
|
assert event.id == i, "Event id mismatch"
|
||||||
@@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber):
|
|||||||
assert len(replayed) > 0, "No replayed messages received"
|
assert len(replayed) > 0, "No replayed messages received"
|
||||||
seqs = [seq for seq, _ in replayed]
|
seqs = [seq for seq, _ in replayed]
|
||||||
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
|
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
|
||||||
assert seqs == list(range(min(seqs),
|
assert seqs == list(range(min(seqs), max(seqs) + 1)), (
|
||||||
max(seqs) +
|
"Replayed messages not consecutive"
|
||||||
1)), ("Replayed messages not consecutive")
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_buffer_limit(publisher, subscriber, publisher_config):
|
def test_buffer_limit(publisher, subscriber, publisher_config):
|
||||||
@@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config):
|
|||||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||||
|
|
||||||
from .conftest import MockSubscriber
|
from .conftest import MockSubscriber
|
||||||
|
|
||||||
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
||||||
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
|
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
|
||||||
|
|
||||||
@@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config):
|
|||||||
|
|
||||||
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
|
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
|
||||||
assert all(msg is not None for msg in foo_received), (
|
assert all(msg is not None for msg in foo_received), (
|
||||||
"Subscriber with matching topic should receive messages")
|
"Subscriber with matching topic should receive messages"
|
||||||
|
)
|
||||||
|
|
||||||
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
|
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
|
||||||
assert all(msg is None for msg in bar_received), (
|
assert all(msg is None for msg in bar_received), (
|
||||||
"Subscriber with non-matching topic should receive no messages")
|
"Subscriber with non-matching topic should receive no messages"
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
pub.shutdown()
|
pub.shutdown()
|
||||||
sub_foo.close()
|
sub_foo.close()
|
||||||
@@ -178,8 +184,7 @@ def test_high_volume(publisher, subscriber):
|
|||||||
|
|
||||||
publisher_thread.join()
|
publisher_thread.join()
|
||||||
|
|
||||||
assert len(received) >= num_batches * 0.9, (
|
assert len(received) >= num_batches * 0.9, "We should have received most messages"
|
||||||
"We should have received most messages")
|
|
||||||
|
|
||||||
seqs = [seq for seq, _ in received]
|
seqs = [seq for seq, _ in received]
|
||||||
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
||||||
@@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
|||||||
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
|
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
|
||||||
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
|
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
|
||||||
expected_endpoint_1 = base_endpoint.replace(
|
expected_endpoint_1 = base_endpoint.replace(
|
||||||
":5557", ":5558") # rank 1 gets port + 1
|
":5557", ":5558"
|
||||||
|
) # rank 1 gets port + 1
|
||||||
else:
|
else:
|
||||||
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
|
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
|
||||||
expected_endpoint_0 = base_endpoint # rank 0 gets base
|
expected_endpoint_0 = base_endpoint # rank 0 gets base
|
||||||
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
|
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
|
||||||
|
|
||||||
from .conftest import MockSubscriber
|
from .conftest import MockSubscriber
|
||||||
|
|
||||||
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
|
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
|
||||||
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
|
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
|
||||||
|
|
||||||
@@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
|||||||
|
|
||||||
# Verify DP rank tagging
|
# Verify DP rank tagging
|
||||||
assert received_0.data_parallel_rank == 0, (
|
assert received_0.data_parallel_rank == 0, (
|
||||||
f"Expected DP rank 0, got {received_0.data_parallel_rank}")
|
f"Expected DP rank 0, got {received_0.data_parallel_rank}"
|
||||||
|
)
|
||||||
assert received_1.data_parallel_rank == 1, (
|
assert received_1.data_parallel_rank == 1, (
|
||||||
f"Expected DP rank 1, got {received_1.data_parallel_rank}")
|
f"Expected DP rank 1, got {received_1.data_parallel_rank}"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify event content is correct
|
# Verify event content is correct
|
||||||
assert len(
|
assert len(received_0.events) == 2, "Wrong number of events from rank 0"
|
||||||
received_0.events) == 2, "Wrong number of events from rank 0"
|
assert len(received_1.events) == 3, "Wrong number of events from rank 1"
|
||||||
assert len(
|
|
||||||
received_1.events) == 3, "Wrong number of events from rank 1"
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
pub_0.shutdown()
|
pub_0.shutdown()
|
||||||
|
|||||||
@@ -46,28 +46,24 @@ class EPTestSettings:
|
|||||||
):
|
):
|
||||||
return EPTestSettings(
|
return EPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False),
|
||||||
eager_mode=False,
|
ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True),
|
||||||
chunked_prefill=False),
|
ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
eager_mode=False,
|
tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True
|
||||||
chunked_prefill=True),
|
),
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
eager_mode=True,
|
tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False
|
||||||
chunked_prefill=False),
|
),
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
eager_mode=False,
|
|
||||||
chunked_prefill=True),
|
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
eager_mode=True,
|
|
||||||
chunked_prefill=False),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=EPTestOptions(trust_remote_code=trust_remote_code,
|
test_options=EPTestOptions(
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
tokenizer_mode=tokenizer_mode,
|
tokenizer_mode=tokenizer_mode,
|
||||||
load_format=load_format,
|
load_format=load_format,
|
||||||
hf_overrides=hf_overrides),
|
hf_overrides=hf_overrides,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -82,16 +78,16 @@ class EPTestSettings:
|
|||||||
):
|
):
|
||||||
return EPTestSettings(
|
return EPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
|
||||||
eager_mode=True,
|
|
||||||
chunked_prefill=False),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp"],
|
distributed_backends=["mp"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=EPTestOptions(trust_remote_code=trust_remote_code,
|
test_options=EPTestOptions(
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
tokenizer_mode=tokenizer_mode,
|
tokenizer_mode=tokenizer_mode,
|
||||||
load_format=load_format,
|
load_format=load_format,
|
||||||
hf_overrides=hf_overrides),
|
hf_overrides=hf_overrides,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_name: str):
|
def iter_params(self, model_name: str):
|
||||||
@@ -99,8 +95,13 @@ class EPTestSettings:
|
|||||||
|
|
||||||
for parallel_setup in self.parallel_setups:
|
for parallel_setup in self.parallel_setups:
|
||||||
for distributed_backend in self.distributed_backends:
|
for distributed_backend in self.distributed_backends:
|
||||||
yield (model_name, parallel_setup, distributed_backend,
|
yield (
|
||||||
self.runner, opts)
|
model_name,
|
||||||
|
parallel_setup,
|
||||||
|
distributed_backend,
|
||||||
|
self.runner,
|
||||||
|
opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# NOTE: You can adjust tp_base locally to fit the model in GPU
|
# NOTE: You can adjust tp_base locally to fit the model in GPU
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ import pytest
|
|||||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||||
|
|
||||||
|
|
||||||
def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
|
def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts):
|
||||||
global_num_experts):
|
|
||||||
"""Verify that the expert map follows the round_robin pattern."""
|
"""Verify that the expert map follows the round_robin pattern."""
|
||||||
# Calculate expected local experts (supporting non-divisible cases)
|
# Calculate expected local experts (supporting non-divisible cases)
|
||||||
base_experts = global_num_experts // ep_size
|
base_experts = global_num_experts // ep_size
|
||||||
@@ -30,24 +29,21 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
|
|||||||
if global_expert_id in expected_expert_ids:
|
if global_expert_id in expected_expert_ids:
|
||||||
local_expert_id = expert_map[global_expert_id]
|
local_expert_id = expert_map[global_expert_id]
|
||||||
expected_local_id = expected_expert_ids.index(global_expert_id)
|
expected_local_id = expected_expert_ids.index(global_expert_id)
|
||||||
assert (
|
assert local_expert_id == expected_local_id, (
|
||||||
local_expert_id == expected_local_id
|
f"Global expert {global_expert_id} should map to local expert "
|
||||||
), f"Global expert {global_expert_id} should map to local expert " \
|
|
||||||
f"{expected_local_id}, got {local_expert_id}"
|
f"{expected_local_id}, got {local_expert_id}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert expert_map[global_expert_id] == -1, (
|
||||||
expert_map[global_expert_id] == -1
|
f"Global expert {global_expert_id} should not be mapped to this rank"
|
||||||
), f"Global expert {global_expert_id} should not be mapped to " \
|
)
|
||||||
f"this rank"
|
|
||||||
|
|
||||||
# Verify that all local expert IDs are consecutive starting from 0
|
# Verify that all local expert IDs are consecutive starting from 0
|
||||||
local_expert_ids = [
|
local_expert_ids = [expert_map[global_id] for global_id in expected_expert_ids]
|
||||||
expert_map[global_id] for global_id in expected_expert_ids
|
|
||||||
]
|
|
||||||
expected_local_ids = list(range(local_num_experts))
|
expected_local_ids = list(range(local_num_experts))
|
||||||
assert (
|
assert local_expert_ids == expected_local_ids, (
|
||||||
local_expert_ids == expected_local_ids
|
f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
|
||||||
), f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
||||||
@@ -78,8 +74,9 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
|
|||||||
|
|
||||||
for test_global_experts, test_ep_size in test_cases:
|
for test_global_experts, test_ep_size in test_cases:
|
||||||
# Ensure ep_size matches world_size
|
# Ensure ep_size matches world_size
|
||||||
assert (test_ep_size == world_size
|
assert test_ep_size == world_size, (
|
||||||
), f"ep_size {test_ep_size} must equal world_size {world_size}"
|
f"ep_size {test_ep_size} must equal world_size {world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
# Test each rank
|
# Test each rank
|
||||||
for ep_rank in range(world_size):
|
for ep_rank in range(world_size):
|
||||||
@@ -98,21 +95,22 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
|
|||||||
expert_placement_strategy=expert_placement_strategy,
|
expert_placement_strategy=expert_placement_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert test_local_experts == expected_test_local, (
|
||||||
test_local_experts == expected_test_local
|
f"For {test_global_experts} experts on {test_ep_size} ranks, "
|
||||||
), f"For {test_global_experts} experts on {test_ep_size} ranks, " \
|
f"rank {ep_rank}: expected {expected_test_local} local"
|
||||||
f"rank {ep_rank}: expected {expected_test_local} local" \
|
|
||||||
f"experts, got {test_local_experts}"
|
f"experts, got {test_local_experts}"
|
||||||
|
)
|
||||||
|
|
||||||
if test_expert_map is not None:
|
if test_expert_map is not None:
|
||||||
assert test_expert_map.shape == (
|
assert test_expert_map.shape == (test_global_experts,), (
|
||||||
test_global_experts,
|
f"Expected expert map shape ({test_global_experts},), "
|
||||||
), f"Expected expert map shape ({test_global_experts},), " \
|
|
||||||
f"got {test_expert_map.shape}"
|
f"got {test_expert_map.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify round_robin pattern for this test case
|
# Verify round_robin pattern for this test case
|
||||||
verify_round_robin_pattern(test_expert_map, ep_rank,
|
verify_round_robin_pattern(
|
||||||
test_ep_size, test_global_experts)
|
test_expert_map, ep_rank, test_ep_size, test_global_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
||||||
@@ -147,28 +145,81 @@ def test_determine_expert_map_comprehensive():
|
|||||||
# expert_placement_strategy, expected_local, expected_map_pattern)
|
# expert_placement_strategy, expected_local, expected_map_pattern)
|
||||||
test_cases = [
|
test_cases = [
|
||||||
# Round robin placement tests
|
# Round robin placement tests
|
||||||
(2, 0, 8, "round_robin", 4, [0, -1, 1, -1, 2, -1, 3,
|
(
|
||||||
-1]), # rank 0 gets even experts
|
2,
|
||||||
(2, 1, 8, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1,
|
0,
|
||||||
3]), # rank 1 gets odd experts
|
8,
|
||||||
(2, 0, 9, "round_robin", 5, [0, -1, 1, -1, 2, -1, 3, -1, 4
|
"round_robin",
|
||||||
]), # rank 0 gets 5 experts (even + last)
|
4,
|
||||||
(2, 1, 9, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3,
|
[0, -1, 1, -1, 2, -1, 3, -1],
|
||||||
-1]), # rank 1 gets 4 experts (odd)
|
), # rank 0 gets even experts
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
4,
|
||||||
|
[-1, 0, -1, 1, -1, 2, -1, 3],
|
||||||
|
), # rank 1 gets odd experts
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
0,
|
||||||
|
9,
|
||||||
|
"round_robin",
|
||||||
|
5,
|
||||||
|
[0, -1, 1, -1, 2, -1, 3, -1, 4],
|
||||||
|
), # rank 0 gets 5 experts (even + last)
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
9,
|
||||||
|
"round_robin",
|
||||||
|
4,
|
||||||
|
[-1, 0, -1, 1, -1, 2, -1, 3, -1],
|
||||||
|
), # rank 1 gets 4 experts (odd)
|
||||||
# 4-rank tests
|
# 4-rank tests
|
||||||
(4, 0, 8, "round_robin", 2, [0, -1, -1, -1, 1, -1, -1,
|
(
|
||||||
-1]), # rank 0 gets experts 0, 4
|
4,
|
||||||
(4, 1, 8, "round_robin", 2, [-1, 0, -1, -1, -1, 1, -1,
|
0,
|
||||||
-1]), # rank 1 gets experts 1, 5
|
8,
|
||||||
(4, 2, 8, "round_robin", 2, [-1, -1, 0, -1, -1, -1, 1,
|
"round_robin",
|
||||||
-1]), # rank 2 gets experts 2, 6
|
2,
|
||||||
(4, 3, 8, "round_robin", 2, [-1, -1, -1, 0, -1, -1, -1,
|
[0, -1, -1, -1, 1, -1, -1, -1],
|
||||||
1]), # rank 3 gets experts 3, 7
|
), # rank 0 gets experts 0, 4
|
||||||
|
(
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
2,
|
||||||
|
[-1, 0, -1, -1, -1, 1, -1, -1],
|
||||||
|
), # rank 1 gets experts 1, 5
|
||||||
|
(
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
2,
|
||||||
|
[-1, -1, 0, -1, -1, -1, 1, -1],
|
||||||
|
), # rank 2 gets experts 2, 6
|
||||||
|
(
|
||||||
|
4,
|
||||||
|
3,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
2,
|
||||||
|
[-1, -1, -1, 0, -1, -1, -1, 1],
|
||||||
|
), # rank 3 gets experts 3, 7
|
||||||
]
|
]
|
||||||
|
|
||||||
for ep_size, ep_rank, global_num_experts, expert_placement_strategy, \
|
for (
|
||||||
expected_local, expected_map_pattern in test_cases:
|
ep_size,
|
||||||
|
ep_rank,
|
||||||
|
global_num_experts,
|
||||||
|
expert_placement_strategy,
|
||||||
|
expected_local,
|
||||||
|
expected_map_pattern,
|
||||||
|
) in test_cases:
|
||||||
local_num_experts, expert_map = determine_expert_map(
|
local_num_experts, expert_map = determine_expert_map(
|
||||||
ep_size=ep_size,
|
ep_size=ep_size,
|
||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
@@ -176,19 +227,21 @@ def test_determine_expert_map_comprehensive():
|
|||||||
expert_placement_strategy=expert_placement_strategy,
|
expert_placement_strategy=expert_placement_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert local_num_experts == expected_local, \
|
assert local_num_experts == expected_local, (
|
||||||
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
|
f"ep_size={ep_size}, ep_rank={ep_rank}, "
|
||||||
f"global_num_experts={global_num_experts}, " \
|
f"global_num_experts={global_num_experts}, "
|
||||||
f"expert_placement_strategy={expert_placement_strategy}: " \
|
f"expert_placement_strategy={expert_placement_strategy}: "
|
||||||
f"expected {expected_local} local experts, got {local_num_experts}"
|
f"expected {expected_local} local experts, got {local_num_experts}"
|
||||||
|
)
|
||||||
|
|
||||||
if expected_map_pattern is None:
|
if expected_map_pattern is None:
|
||||||
assert expert_map is None, "Expected expert_map to be None"
|
assert expert_map is None, "Expected expert_map to be None"
|
||||||
else:
|
else:
|
||||||
assert expert_map is not None, "Expected expert_map to not be None"
|
assert expert_map is not None, "Expected expert_map to not be None"
|
||||||
actual_map = expert_map.tolist()
|
actual_map = expert_map.tolist()
|
||||||
assert actual_map == expected_map_pattern, \
|
assert actual_map == expected_map_pattern, (
|
||||||
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
|
f"ep_size={ep_size}, ep_rank={ep_rank}, "
|
||||||
f"global_num_experts={global_num_experts}, " \
|
f"global_num_experts={global_num_experts}, "
|
||||||
f"expert_placement_strategy={expert_placement_strategy}: " \
|
f"expert_placement_strategy={expert_placement_strategy}: "
|
||||||
f"expected map {expected_map_pattern}, got {actual_map}"
|
f"expected map {expected_map_pattern}, got {actual_map}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
|
from vllm.config import (
|
||||||
VllmConfig, set_current_vllm_config)
|
DeviceConfig,
|
||||||
|
KVTransferConfig,
|
||||||
|
ModelConfig,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||||
get_kv_connector_cache_layout)
|
get_kv_connector_cache_layout,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger("test_expert_parallel")
|
logger = init_logger("test_expert_parallel")
|
||||||
@@ -23,8 +29,9 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector():
|
|||||||
kv_connector="LMCacheConnectorV1",
|
kv_connector="LMCacheConnectorV1",
|
||||||
kv_role="kv_both",
|
kv_role="kv_both",
|
||||||
)
|
)
|
||||||
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
vllm_config = VllmConfig(
|
||||||
kv_transfer_config=kv_transfer_config)
|
device_config=DeviceConfig("cpu"), kv_transfer_config=kv_transfer_config
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
# Test with default settings
|
# Test with default settings
|
||||||
layout = get_kv_connector_cache_layout()
|
layout = get_kv_connector_cache_layout()
|
||||||
@@ -37,9 +44,11 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
|
|||||||
kv_role="kv_both",
|
kv_role="kv_both",
|
||||||
)
|
)
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
vllm_config = VllmConfig(
|
||||||
|
device_config=DeviceConfig("cpu"),
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
kv_transfer_config=kv_transfer_config)
|
kv_transfer_config=kv_transfer_config,
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
# Test with default settings
|
# Test with default settings
|
||||||
layout = get_kv_connector_cache_layout()
|
layout = get_kv_connector_cache_layout()
|
||||||
@@ -47,25 +56,22 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
|
|||||||
|
|
||||||
|
|
||||||
def test_get_kv_connector_cache_layout_with_multi_connector():
|
def test_get_kv_connector_cache_layout_with_multi_connector():
|
||||||
kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector",
|
kv_transfer_config = KVTransferConfig(
|
||||||
|
kv_connector="MultiConnector",
|
||||||
kv_role="kv_both",
|
kv_role="kv_both",
|
||||||
kv_connector_extra_config={
|
kv_connector_extra_config={
|
||||||
"connectors": [{
|
"connectors": [
|
||||||
"kv_connector":
|
{"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
|
||||||
"SharedStorageConnector",
|
{"kv_connector": "NixlConnector", "kv_role": "kv_both"},
|
||||||
"kv_role":
|
]
|
||||||
"kv_both"
|
},
|
||||||
}, {
|
)
|
||||||
"kv_connector":
|
|
||||||
"NixlConnector",
|
|
||||||
"kv_role":
|
|
||||||
"kv_both"
|
|
||||||
}]
|
|
||||||
})
|
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
vllm_config = VllmConfig(
|
||||||
|
device_config=DeviceConfig("cpu"),
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
kv_transfer_config=kv_transfer_config)
|
kv_transfer_config=kv_transfer_config,
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
# Test with default settings
|
# Test with default settings
|
||||||
layout = get_kv_connector_cache_layout()
|
layout = get_kv_connector_cache_layout()
|
||||||
|
|||||||
@@ -24,14 +24,13 @@ from vllm.utils import get_ip
|
|||||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not VLLM_MULTI_NODE,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 nodes to run the test.")
|
not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test."
|
||||||
|
)
|
||||||
def test_multi_node_assignment() -> None:
|
def test_multi_node_assignment() -> None:
|
||||||
|
|
||||||
# NOTE: important to keep this class definition here
|
# NOTE: important to keep this class definition here
|
||||||
# to let ray use cloudpickle to serialize it.
|
# to let ray use cloudpickle to serialize it.
|
||||||
class Actor:
|
class Actor:
|
||||||
|
|
||||||
def get_ip(self):
|
def get_ip(self):
|
||||||
return get_ip()
|
return get_ip()
|
||||||
|
|
||||||
@@ -41,8 +40,7 @@ def test_multi_node_assignment() -> None:
|
|||||||
|
|
||||||
current_ip = get_ip()
|
current_ip = get_ip()
|
||||||
workers = []
|
workers = []
|
||||||
for bundle_id, bundle in enumerate(
|
for bundle_id, bundle in enumerate(config.placement_group.bundle_specs):
|
||||||
config.placement_group.bundle_specs):
|
|
||||||
if not bundle.get("GPU", 0):
|
if not bundle.get("GPU", 0):
|
||||||
continue
|
continue
|
||||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||||
|
|||||||
@@ -11,15 +11,17 @@ import torch.multiprocessing as mp
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||||
CudaCommunicator)
|
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
|
||||||
from vllm.distributed.device_communicators.pynccl import (
|
|
||||||
register_nccl_symmetric_ops)
|
|
||||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||||
get_nccl_mem_pool, is_symmetric_memory_enabled)
|
get_nccl_mem_pool,
|
||||||
from vllm.distributed.parallel_state import (get_tp_group,
|
is_symmetric_memory_enabled,
|
||||||
|
)
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_tp_group,
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel)
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
@@ -38,31 +40,32 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
|||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
|
{
|
||||||
"RANK": str(local_rank),
|
"RANK": str(local_rank),
|
||||||
"LOCAL_RANK": str(local_rank),
|
"LOCAL_RANK": str(local_rank),
|
||||||
"WORLD_SIZE": str(world_size),
|
"WORLD_SIZE": str(world_size),
|
||||||
"MASTER_ADDR": "localhost",
|
"MASTER_ADDR": "localhost",
|
||||||
"MASTER_PORT": "12345",
|
"MASTER_PORT": "12345",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
cuda_communicator = typing.cast(CudaCommunicator,
|
cuda_communicator = typing.cast(
|
||||||
get_tp_group().device_communicator)
|
CudaCommunicator, get_tp_group().device_communicator
|
||||||
|
)
|
||||||
pynccl_comm = cuda_communicator.pynccl_comm
|
pynccl_comm = cuda_communicator.pynccl_comm
|
||||||
if get_nccl_mem_pool() is None:
|
if get_nccl_mem_pool() is None:
|
||||||
pytest.skip("NCCL allocator compilation failed "
|
pytest.skip(
|
||||||
"(probably missing NCCL headers).")
|
"NCCL allocator compilation failed (probably missing NCCL headers)."
|
||||||
|
)
|
||||||
if not is_symmetric_memory_enabled():
|
if not is_symmetric_memory_enabled():
|
||||||
pytest.skip("NCCL symmetric memory allreduce is disabled.")
|
pytest.skip("NCCL symmetric memory allreduce is disabled.")
|
||||||
|
|
||||||
register_nccl_symmetric_ops(pynccl_comm)
|
register_nccl_symmetric_ops(pynccl_comm)
|
||||||
input = torch.randint(1,
|
input = torch.randint(1, 23, (test_size_elements,), dtype=dtype, device=device)
|
||||||
23, (test_size_elements, ),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
input_clone = input.clone()
|
input_clone = input.clone()
|
||||||
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
|
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
|
||||||
assert output is not None
|
assert output is not None
|
||||||
@@ -77,8 +80,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
|||||||
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
|
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("world_size", [2])
|
@pytest.mark.parametrize("world_size", [2])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
|
||||||
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
@@ -88,7 +90,5 @@ def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
|||||||
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
|
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
|
||||||
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
|
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
|
||||||
|
|
||||||
mp.spawn(nccl_symm_mem_allreduce_worker,
|
mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size)
|
||||||
args=(world_size, ),
|
|
||||||
nprocs=world_size)
|
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|||||||
@@ -32,12 +32,15 @@ if __name__ == "__main__":
|
|||||||
# Expected node count based on environment variable)
|
# Expected node count based on environment variable)
|
||||||
expected = int(os.environ.get("NUM_NODES", "1"))
|
expected = int(os.environ.get("NUM_NODES", "1"))
|
||||||
|
|
||||||
assert test_result == expected, \
|
assert test_result == expected, f"Expected {expected} nodes, got {test_result}"
|
||||||
f"Expected {expected} nodes, got {test_result}"
|
|
||||||
|
|
||||||
if pg == dist.group.WORLD:
|
if pg == dist.group.WORLD:
|
||||||
print(f"Node count test passed! Got {test_result} nodes "
|
print(
|
||||||
f"when using torch distributed!")
|
f"Node count test passed! Got {test_result} nodes "
|
||||||
|
f"when using torch distributed!"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Node count test passed! Got {test_result} nodes "
|
print(
|
||||||
f"when using StatelessProcessGroup!")
|
f"Node count test passed! Got {test_result} nodes "
|
||||||
|
f"when using StatelessProcessGroup!"
|
||||||
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
|||||||
all workers in a node other than the head node, which can cause the test
|
all workers in a node other than the head node, which can cause the test
|
||||||
to fail.
|
to fail.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -55,26 +56,17 @@ class PPTestSettings:
|
|||||||
):
|
):
|
||||||
return PPTestSettings(
|
return PPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=False),
|
||||||
pp_size=pp_base,
|
ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=False),
|
||||||
eager_mode=False),
|
ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=True),
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=False),
|
||||||
pp_size=2 * pp_base,
|
ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=True),
|
||||||
eager_mode=False),
|
|
||||||
ParallelSetup(tp_size=tp_base,
|
|
||||||
pp_size=2 * pp_base,
|
|
||||||
eager_mode=True),
|
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
pp_size=pp_base,
|
|
||||||
eager_mode=False),
|
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
pp_size=pp_base,
|
|
||||||
eager_mode=True),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=PPTestOptions(multi_node_only=multi_node_only,
|
test_options=PPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -86,17 +78,15 @@ class PPTestSettings:
|
|||||||
multi_node_only: bool = False,
|
multi_node_only: bool = False,
|
||||||
load_format: Optional[str] = None,
|
load_format: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
return PPTestSettings(
|
return PPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=True),
|
||||||
pp_size=pp_base,
|
|
||||||
eager_mode=True),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp"],
|
distributed_backends=["mp"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=PPTestOptions(multi_node_only=multi_node_only,
|
test_options=PPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_id: str):
|
def iter_params(self, model_id: str):
|
||||||
@@ -281,8 +271,10 @@ def _compare_tp(
|
|||||||
if num_gpus_available < tp_size * pp_size:
|
if num_gpus_available < tp_size * pp_size:
|
||||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
pytest.skip(
|
||||||
"multiprocessing distributed backend")
|
"Skipping multi-node pipeline parallel test for "
|
||||||
|
"multiprocessing distributed backend"
|
||||||
|
)
|
||||||
if multi_node_only and not VLLM_MULTI_NODE:
|
if multi_node_only and not VLLM_MULTI_NODE:
|
||||||
pytest.skip("Not in multi-node setting")
|
pytest.skip("Not in multi-node setting")
|
||||||
|
|
||||||
@@ -357,20 +349,16 @@ def _compare_tp(
|
|||||||
"mp",
|
"mp",
|
||||||
]
|
]
|
||||||
|
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(model_id, pp_args, tp_args, pp_env, tp_env, method=method)
|
||||||
pp_args,
|
|
||||||
tp_args,
|
|
||||||
pp_env,
|
|
||||||
tp_env,
|
|
||||||
method=method)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||||
"test_options"),
|
|
||||||
[
|
[
|
||||||
params for model_id, settings in TEXT_GENERATION_MODELS.items()
|
params
|
||||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
for model_id, settings in TEXT_GENERATION_MODELS.items()
|
||||||
|
for params in settings.iter_params(model_id)
|
||||||
|
if model_id in TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@@ -382,22 +370,25 @@ def test_tp_language_generation(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_tp(model_id,
|
_compare_tp(
|
||||||
|
model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
runner,
|
runner,
|
||||||
test_options,
|
test_options,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
method="generate",
|
method="generate",
|
||||||
is_multimodal=False)
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||||
"test_options"),
|
|
||||||
[
|
[
|
||||||
params for model_id, settings in EMBEDDING_MODELS.items()
|
params
|
||||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
for model_id, settings in EMBEDDING_MODELS.items()
|
||||||
|
for params in settings.iter_params(model_id)
|
||||||
|
if model_id in TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@@ -409,22 +400,25 @@ def test_tp_language_embedding(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_tp(model_id,
|
_compare_tp(
|
||||||
|
model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
runner,
|
runner,
|
||||||
test_options,
|
test_options,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
method="encode",
|
method="encode",
|
||||||
is_multimodal=False)
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||||
"test_options"),
|
|
||||||
[
|
[
|
||||||
params for model_id, settings in MULTIMODAL_MODELS.items()
|
params
|
||||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
for model_id, settings in MULTIMODAL_MODELS.items()
|
||||||
|
for params in settings.iter_params(model_id)
|
||||||
|
if model_id in TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@@ -436,11 +430,13 @@ def test_tp_multimodal_generation(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_tp(model_id,
|
_compare_tp(
|
||||||
|
model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
runner,
|
runner,
|
||||||
test_options,
|
test_options,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
method="generate",
|
method="generate",
|
||||||
is_multimodal=True)
|
is_multimodal=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from vllm.distributed.utils import get_pp_indices
|
|||||||
|
|
||||||
|
|
||||||
def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
|
|
||||||
def _verify(partition_str, num_layers, pp_size, goldens):
|
def _verify(partition_str, num_layers, pp_size, goldens):
|
||||||
@@ -57,7 +56,8 @@ def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
|||||||
(5, 3, 0, (0, 2)),
|
(5, 3, 0, (0, 2)),
|
||||||
(5, 3, 1, (2, 4)),
|
(5, 3, 1, (2, 4)),
|
||||||
(5, 3, 2, (4, 5)),
|
(5, 3, 2, (4, 5)),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_uneven_auto_partition(
|
def test_uneven_auto_partition(
|
||||||
num_hidden_layers: int,
|
num_hidden_layers: int,
|
||||||
pp_size: int,
|
pp_size: int,
|
||||||
|
|||||||
@@ -12,12 +12,18 @@ if TYPE_CHECKING:
|
|||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
|
@pytest.mark.parametrize(
|
||||||
|
"PP_SIZE, MODEL_NAME",
|
||||||
|
[
|
||||||
(2, "JackFram/llama-160m"),
|
(2, "JackFram/llama-160m"),
|
||||||
])
|
],
|
||||||
@pytest.mark.parametrize("ATTN_BACKEND", [
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"ATTN_BACKEND",
|
||||||
|
[
|
||||||
"FLASH_ATTN",
|
"FLASH_ATTN",
|
||||||
])
|
],
|
||||||
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_pp_cudagraph(
|
def test_pp_cudagraph(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
@@ -9,13 +9,15 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||||
tensor_model_parallel_all_reduce)
|
|
||||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
from vllm.distributed.parallel_state import (
|
||||||
get_world_group, graph_capture,
|
ensure_model_parallel_initialized,
|
||||||
init_distributed_environment)
|
get_world_group,
|
||||||
|
graph_capture,
|
||||||
|
init_distributed_environment,
|
||||||
|
)
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
|
|
||||||
@@ -24,13 +26,13 @@ def distributed_run(fn, world_size):
|
|||||||
processes: list[multiprocessing.Process] = []
|
processes: list[multiprocessing.Process] = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env: dict[str, str] = {}
|
env: dict[str, str] = {}
|
||||||
env['RANK'] = str(i)
|
env["RANK"] = str(i)
|
||||||
env['LOCAL_RANK'] = str(i)
|
env["LOCAL_RANK"] = str(i)
|
||||||
env['WORLD_SIZE'] = str(number_of_processes)
|
env["WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['MASTER_ADDR'] = 'localhost'
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env['MASTER_PORT'] = '12345'
|
env["MASTER_PORT"] = "12345"
|
||||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
p = multiprocessing.Process(target=fn, args=(env,))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
@@ -47,7 +49,7 @@ def worker_fn_wrapper(fn):
|
|||||||
# and update the environment variables in the function
|
# and update the environment variables in the function
|
||||||
def wrapped_fn(env):
|
def wrapped_fn(env):
|
||||||
update_environment_variables(env)
|
update_environment_variables(env)
|
||||||
local_rank = os.environ['LOCAL_RANK']
|
local_rank = os.environ["LOCAL_RANK"]
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -58,17 +60,18 @@ def worker_fn_wrapper(fn):
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
tensor = torch.ones(16, 1024, 1024,
|
)
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
tensor = pynccl_comm.all_reduce(tensor)
|
tensor = pynccl_comm.all_reduce(tensor)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
|
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl():
|
def test_pynccl():
|
||||||
distributed_run(worker_fn, 2)
|
distributed_run(worker_fn, 2)
|
||||||
|
|
||||||
@@ -78,7 +81,7 @@ def multiple_allreduce_worker_fn():
|
|||||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||||
groups = [
|
groups = [
|
||||||
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
||||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo")
|
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
|
||||||
]
|
]
|
||||||
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
||||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||||
@@ -95,8 +98,9 @@ def multiple_allreduce_worker_fn():
|
|||||||
assert torch.all(tensor == 2).cpu().item()
|
assert torch.all(tensor == 2).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_multiple_allreduce():
|
def test_pynccl_multiple_allreduce():
|
||||||
# this tests pynccl for multiple tp groups, in a standalone way
|
# this tests pynccl for multiple tp groups, in a standalone way
|
||||||
# i.e. call `pynccl_comm.all_reduce` directly
|
# i.e. call `pynccl_comm.all_reduce` directly
|
||||||
@@ -121,8 +125,9 @@ def multiple_allreduce_with_vllm_worker_fn():
|
|||||||
assert torch.all(tensor == 2).cpu().item()
|
assert torch.all(tensor == 2).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_multiple_allreduce_with_vllm():
|
def test_pynccl_multiple_allreduce_with_vllm():
|
||||||
# this tests pynccl for multiple tp groups, together with vllm
|
# this tests pynccl for multiple tp groups, together with vllm
|
||||||
# i.e. call `tensor_model_parallel_all_reduce`
|
# i.e. call `tensor_model_parallel_all_reduce`
|
||||||
@@ -133,10 +138,11 @@ def test_pynccl_multiple_allreduce_with_vllm():
|
|||||||
def worker_fn_with_cudagraph():
|
def worker_fn_with_cudagraph():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
# run something in the default stream to initialize torch engine
|
# run something in the default stream to initialize torch engine
|
||||||
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
|
a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
with torch.cuda.graph(graph):
|
with torch.cuda.graph(graph):
|
||||||
a_out = pynccl_comm.all_reduce(a)
|
a_out = pynccl_comm.all_reduce(a)
|
||||||
@@ -148,84 +154,90 @@ def worker_fn_with_cudagraph():
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def all_gather_worker_fn():
|
def all_gather_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
num_elems = 1000
|
num_elems = 1000
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = (
|
||||||
device=device) + rank * num_elems
|
torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
|
||||||
result = torch.zeros(num_elems * world_size,
|
)
|
||||||
dtype=torch.float32,
|
result = torch.zeros(num_elems * world_size, dtype=torch.float32, device=device)
|
||||||
device=device)
|
|
||||||
|
|
||||||
expected = torch.cat([
|
expected = torch.cat(
|
||||||
|
[
|
||||||
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||||
for r in range(world_size)
|
for r in range(world_size)
|
||||||
]).to(device)
|
]
|
||||||
|
).to(device)
|
||||||
|
|
||||||
pynccl_comm.all_gather(result, tensor)
|
pynccl_comm.all_gather(result, tensor)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_all_gather():
|
def test_pynccl_all_gather():
|
||||||
distributed_run(all_gather_worker_fn, 2)
|
distributed_run(all_gather_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def all_gatherv_worker_fn():
|
def all_gatherv_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
assert world_size <= 8
|
assert world_size <= 8
|
||||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||||
num_elems = sizes[rank]
|
num_elems = sizes[rank]
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
|
||||||
device=device) + rank * 100
|
|
||||||
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
|
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
|
||||||
|
|
||||||
expected = torch.cat([
|
expected = torch.cat(
|
||||||
|
[
|
||||||
torch.arange(sizes[r], dtype=torch.float32) + r * 100
|
torch.arange(sizes[r], dtype=torch.float32) + r * 100
|
||||||
for r in range(world_size)
|
for r in range(world_size)
|
||||||
]).to(device)
|
]
|
||||||
|
).to(device)
|
||||||
|
|
||||||
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
|
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_all_gatherv():
|
def test_pynccl_all_gatherv():
|
||||||
distributed_run(all_gatherv_worker_fn, 2)
|
distributed_run(all_gatherv_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def reduce_scatter_worker_fn():
|
def reduce_scatter_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
num_elems = 1000
|
num_elems = 1000
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = (
|
||||||
device=device) + rank * num_elems
|
torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
|
||||||
assert (num_elems % world_size == 0)
|
)
|
||||||
result = torch.zeros(num_elems // world_size,
|
assert num_elems % world_size == 0
|
||||||
dtype=torch.float32,
|
result = torch.zeros(num_elems // world_size, dtype=torch.float32, device=device)
|
||||||
device=device)
|
|
||||||
|
|
||||||
# Calculate expected result for this rank's chunk
|
# Calculate expected result for this rank's chunk
|
||||||
scattered_size = num_elems // world_size
|
scattered_size = num_elems // world_size
|
||||||
@@ -233,34 +245,37 @@ def reduce_scatter_worker_fn():
|
|||||||
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||||
for r in range(world_size)
|
for r in range(world_size)
|
||||||
]
|
]
|
||||||
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
|
expected = sum(
|
||||||
for tensor in all_tensors).to(device)
|
tensor[rank * scattered_size : (rank + 1) * scattered_size]
|
||||||
|
for tensor in all_tensors
|
||||||
|
).to(device)
|
||||||
|
|
||||||
pynccl_comm.reduce_scatter(result, tensor)
|
pynccl_comm.reduce_scatter(result, tensor)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_reduce_scatter():
|
def test_pynccl_reduce_scatter():
|
||||||
distributed_run(reduce_scatter_worker_fn, 2)
|
distributed_run(reduce_scatter_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def reduce_scatterv_worker_fn():
|
def reduce_scatterv_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
assert world_size <= 8
|
assert world_size <= 8
|
||||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||||
num_elems = sum(sizes)
|
num_elems = sum(sizes)
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
|
||||||
device=device) + rank * 100
|
|
||||||
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
|
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
# Calculate expected result for this rank's chunk
|
# Calculate expected result for this rank's chunk
|
||||||
@@ -278,41 +293,41 @@ def reduce_scatterv_worker_fn():
|
|||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_reduce_scatterv():
|
def test_pynccl_reduce_scatterv():
|
||||||
distributed_run(reduce_scatterv_worker_fn, 2)
|
distributed_run(reduce_scatterv_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_with_cudagraph():
|
def test_pynccl_with_cudagraph():
|
||||||
distributed_run(worker_fn_with_cudagraph, 2)
|
distributed_run(worker_fn_with_cudagraph, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def send_recv_worker_fn():
|
def send_recv_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
if pynccl_comm.rank == 0:
|
if pynccl_comm.rank == 0:
|
||||||
tensor = torch.ones(16, 1024, 1024,
|
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
|
||||||
else:
|
else:
|
||||||
tensor = torch.empty(16, 1024, 1024,
|
tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
|
||||||
|
|
||||||
if pynccl_comm.rank == 0:
|
if pynccl_comm.rank == 0:
|
||||||
pynccl_comm.send(tensor,
|
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||||
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
|
||||||
else:
|
else:
|
||||||
pynccl_comm.recv(tensor,
|
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||||
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
assert torch.all(tensor == 1).cpu().item()
|
assert torch.all(tensor == 1).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_send_recv():
|
def test_pynccl_send_recv():
|
||||||
distributed_run(send_recv_worker_fn, 2)
|
distributed_run(send_recv_worker_fn, 2)
|
||||||
|
|
||||||
@@ -322,27 +337,20 @@ def multiple_send_recv_worker_fn():
|
|||||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||||
groups = [
|
groups = [
|
||||||
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
||||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo")
|
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
|
||||||
]
|
]
|
||||||
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
||||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||||
if torch.distributed.get_rank() == 0:
|
if torch.distributed.get_rank() == 0:
|
||||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||||
elif torch.distributed.get_rank() == 1:
|
elif torch.distributed.get_rank() == 1:
|
||||||
tensor = 2 * torch.ones(
|
tensor = 2 * torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||||
16, 1024, 1024, dtype=torch.float32, device=device)
|
|
||||||
else:
|
else:
|
||||||
tensor = torch.empty(16,
|
tensor = torch.empty(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||||
1024,
|
|
||||||
1024,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device)
|
|
||||||
if torch.distributed.get_rank() in [0, 1]:
|
if torch.distributed.get_rank() in [0, 1]:
|
||||||
pynccl_comm.send(tensor,
|
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||||
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
|
||||||
else:
|
else:
|
||||||
pynccl_comm.recv(tensor,
|
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||||
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
if torch.distributed.get_rank() in [0, 2]:
|
if torch.distributed.get_rank() in [0, 2]:
|
||||||
assert torch.all(tensor == 1).cpu().item()
|
assert torch.all(tensor == 1).cpu().item()
|
||||||
@@ -350,14 +358,16 @@ def multiple_send_recv_worker_fn():
|
|||||||
assert torch.all(tensor == 2).cpu().item()
|
assert torch.all(tensor == 2).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_multiple_send_recv():
|
def test_pynccl_multiple_send_recv():
|
||||||
distributed_run(multiple_send_recv_worker_fn, 4)
|
distributed_run(multiple_send_recv_worker_fn, 4)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_broadcast():
|
def test_pynccl_broadcast():
|
||||||
distributed_run(broadcast_worker_fn, 4)
|
distributed_run(broadcast_worker_fn, 4)
|
||||||
|
|
||||||
@@ -366,19 +376,17 @@ def test_pynccl_broadcast():
|
|||||||
def broadcast_worker_fn():
|
def broadcast_worker_fn():
|
||||||
# Test broadcast for every root rank.
|
# Test broadcast for every root rank.
|
||||||
# Essentially this is an all-gather operation.
|
# Essentially this is an all-gather operation.
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
recv_tensors = [
|
recv_tensors = [
|
||||||
torch.empty(16,
|
torch.empty(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
|
||||||
1024,
|
|
||||||
1024,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=pynccl_comm.device)
|
|
||||||
for i in range(pynccl_comm.world_size)
|
for i in range(pynccl_comm.world_size)
|
||||||
]
|
]
|
||||||
recv_tensors[pynccl_comm.rank] = torch.ones(
|
recv_tensors[pynccl_comm.rank] = (
|
||||||
16, 1024, 1024, dtype=torch.float32,
|
torch.ones(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
|
||||||
device=pynccl_comm.device) * pynccl_comm.rank
|
* pynccl_comm.rank
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(pynccl_comm.world_size):
|
for i in range(pynccl_comm.world_size):
|
||||||
pynccl_comm.broadcast(recv_tensors[i], src=i)
|
pynccl_comm.broadcast(recv_tensors[i], src=i)
|
||||||
|
|||||||
@@ -8,20 +8,20 @@ import ray
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||||
tensor_model_parallel_all_reduce)
|
|
||||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import (ensure_model_parallel_initialized,
|
from ..utils import (
|
||||||
init_test_distributed_environment, multi_process_parallel)
|
ensure_model_parallel_initialized,
|
||||||
|
init_test_distributed_environment,
|
||||||
|
multi_process_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
random.seed(44)
|
random.seed(44)
|
||||||
# Size over 8MB is sufficient for custom quick allreduce.
|
# Size over 8MB is sufficient for custom quick allreduce.
|
||||||
test_sizes = [
|
test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
|
||||||
random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
|
|
||||||
]
|
|
||||||
for i, v in enumerate(test_sizes):
|
for i, v in enumerate(test_sizes):
|
||||||
test_sizes[i] -= v % 8
|
test_sizes[i] -= v % 8
|
||||||
|
|
||||||
@@ -38,8 +38,7 @@ def graph_quickreduce(
|
|||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||||
group = get_tp_group().device_group
|
group = get_tp_group().device_group
|
||||||
|
|
||||||
@@ -64,18 +63,15 @@ def graph_quickreduce(
|
|||||||
for sz in test_sizes:
|
for sz in test_sizes:
|
||||||
for dtype in [torch.float16, torch.bfloat16]:
|
for dtype in [torch.float16, torch.bfloat16]:
|
||||||
with graph_capture(device=device) as graph_capture_context:
|
with graph_capture(device=device) as graph_capture_context:
|
||||||
inp1 = torch.randint(1,
|
inp1 = torch.randint(
|
||||||
23, (sz, ),
|
1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
dtype=dtype,
|
)
|
||||||
device=torch.cuda.current_device())
|
inp2 = torch.randint(
|
||||||
inp2 = torch.randint(-23,
|
-23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
1, (sz, ),
|
)
|
||||||
dtype=dtype,
|
|
||||||
device=torch.cuda.current_device())
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(graph,
|
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||||
stream=graph_capture_context.stream):
|
|
||||||
for _ in range(num_communication):
|
for _ in range(num_communication):
|
||||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||||
dist.all_reduce(inp1, group=group)
|
dist.all_reduce(inp1, group=group)
|
||||||
@@ -99,39 +95,42 @@ def eager_quickreduce(
|
|||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
|
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
# Size over 8MB is sufficient for custom quick allreduce.
|
# Size over 8MB is sufficient for custom quick allreduce.
|
||||||
sz = 16 * 1024 * 1024
|
sz = 16 * 1024 * 1024
|
||||||
fa = get_tp_group().device_communicator.qr_comm
|
fa = get_tp_group().device_communicator.qr_comm
|
||||||
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
|
inp = torch.tensor(
|
||||||
dtype=torch.float16,
|
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device
|
||||||
device=device)
|
)
|
||||||
out = fa.quick_all_reduce(inp)
|
out = fa.quick_all_reduce(inp)
|
||||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||||
|
|
||||||
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
|
inp = torch.tensor(
|
||||||
dtype=torch.bfloat16,
|
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
|
||||||
device=device)
|
)
|
||||||
out = fa.quick_all_reduce(inp)
|
out = fa.quick_all_reduce(inp)
|
||||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
@pytest.mark.skipif(
|
||||||
reason="only test quick allreduce for rocm")
|
not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
|
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
|
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
|
||||||
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
|
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
|
||||||
def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
def test_custom_quick_allreduce(
|
||||||
pipeline_parallel_size, test_target,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
quant_mode):
|
tp_size,
|
||||||
|
pipeline_parallel_size,
|
||||||
|
test_target,
|
||||||
|
quant_mode,
|
||||||
|
):
|
||||||
world_size = tp_size * pipeline_parallel_size
|
world_size = tp_size * pipeline_parallel_size
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
||||||
|
|
||||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
|
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||||
test_target)
|
|
||||||
|
|||||||
@@ -22,15 +22,13 @@ if __name__ == "__main__":
|
|||||||
dist.broadcast_object_list(recv, src=0)
|
dist.broadcast_object_list(recv, src=0)
|
||||||
ip, port = recv
|
ip, port = recv
|
||||||
|
|
||||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
|
||||||
dist.get_world_size())
|
|
||||||
|
|
||||||
for pg in [dist.group.WORLD, stateless_pg]:
|
for pg in [dist.group.WORLD, stateless_pg]:
|
||||||
test_result = all(in_the_same_node_as(pg, source_rank=0))
|
test_result = all(in_the_same_node_as(pg, source_rank=0))
|
||||||
|
|
||||||
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
|
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
|
||||||
assert test_result == expected, \
|
assert test_result == expected, f"Expected {expected}, got {test_result}"
|
||||||
f"Expected {expected}, got {test_result}"
|
|
||||||
if pg == dist.group.WORLD:
|
if pg == dist.group.WORLD:
|
||||||
print("Same node test passed! when using torch distributed!")
|
print("Same node test passed! when using torch distributed!")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
|||||||
all workers in a node other than the head node, which can cause the test
|
all workers in a node other than the head node, which can cause the test
|
||||||
to fail.
|
to fail.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -56,7 +57,8 @@ class SPTestSettings:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Length mismatch: distributed_backends "
|
f"Length mismatch: distributed_backends "
|
||||||
f"({len(self.distributed_backends)}) != "
|
f"({len(self.distributed_backends)}) != "
|
||||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
f"vllm_major_versions ({len(self.vllm_major_versions)})"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def detailed(
|
def detailed(
|
||||||
@@ -72,18 +74,22 @@ class SPTestSettings:
|
|||||||
for pp_multiplier in [1, 2]:
|
for pp_multiplier in [1, 2]:
|
||||||
for chunked_prefill_val in [False, True]:
|
for chunked_prefill_val in [False, True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
|
tp_size=tp_base,
|
||||||
pp_size=pp_multiplier * pp_base,
|
pp_size=pp_multiplier * pp_base,
|
||||||
enable_fusion=False,
|
enable_fusion=False,
|
||||||
eager_mode=eager_mode_val,
|
eager_mode=eager_mode_val,
|
||||||
chunked_prefill=chunked_prefill_val))
|
chunked_prefill=chunked_prefill_val,
|
||||||
|
)
|
||||||
|
)
|
||||||
return SPTestSettings(
|
return SPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
vllm_major_versions=["1", "1"],
|
vllm_major_versions=["1", "1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
test_options=SPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -100,18 +106,22 @@ class SPTestSettings:
|
|||||||
for pp_multiplier in [1, 2]:
|
for pp_multiplier in [1, 2]:
|
||||||
for chunked_prefill_val in [False, True]:
|
for chunked_prefill_val in [False, True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
|
tp_size=tp_base,
|
||||||
pp_size=pp_multiplier * pp_base,
|
pp_size=pp_multiplier * pp_base,
|
||||||
enable_fusion=False,
|
enable_fusion=False,
|
||||||
eager_mode=eager_mode_val,
|
eager_mode=eager_mode_val,
|
||||||
chunked_prefill=chunked_prefill_val))
|
chunked_prefill=chunked_prefill_val,
|
||||||
|
)
|
||||||
|
)
|
||||||
return SPTestSettings(
|
return SPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
vllm_major_versions=["1", "1"],
|
vllm_major_versions=["1", "1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
test_options=SPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -126,28 +136,39 @@ class SPTestSettings:
|
|||||||
parallel_setups = []
|
parallel_setups = []
|
||||||
for fusion_val in [False, True]:
|
for fusion_val in [False, True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
|
tp_size=tp_base,
|
||||||
pp_size=pp_base,
|
pp_size=pp_base,
|
||||||
enable_fusion=fusion_val,
|
enable_fusion=fusion_val,
|
||||||
eager_mode=True,
|
eager_mode=True,
|
||||||
chunked_prefill=False))
|
chunked_prefill=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
return SPTestSettings(
|
return SPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
vllm_major_versions=["1", "1"],
|
vllm_major_versions=["1", "1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
test_options=SPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_id: str):
|
def iter_params(self, model_id: str):
|
||||||
opts = self.test_options
|
opts = self.test_options
|
||||||
|
|
||||||
for parallel_setup in self.parallel_setups:
|
for parallel_setup in self.parallel_setups:
|
||||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
for backend, vllm_major_version in zip(
|
||||||
self.vllm_major_versions):
|
self.distributed_backends, self.vllm_major_versions
|
||||||
yield (model_id, parallel_setup, backend, vllm_major_version,
|
):
|
||||||
self.runner, opts)
|
yield (
|
||||||
|
model_id,
|
||||||
|
parallel_setup,
|
||||||
|
backend,
|
||||||
|
vllm_major_version,
|
||||||
|
self.runner,
|
||||||
|
opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _compare_sp(
|
def _compare_sp(
|
||||||
@@ -200,8 +221,10 @@ def _compare_sp(
|
|||||||
if num_gpus_available < tp_size * pp_size:
|
if num_gpus_available < tp_size * pp_size:
|
||||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
pytest.skip(
|
||||||
"multiprocessing distributed backend")
|
"Skipping multi-node pipeline parallel test for "
|
||||||
|
"multiprocessing distributed backend"
|
||||||
|
)
|
||||||
if multi_node_only and not VLLM_MULTI_NODE:
|
if multi_node_only and not VLLM_MULTI_NODE:
|
||||||
pytest.skip("Not in multi-node setting")
|
pytest.skip("Not in multi-node setting")
|
||||||
|
|
||||||
@@ -232,13 +255,13 @@ def _compare_sp(
|
|||||||
common_args.append("--skip-tokenizer-init")
|
common_args.append("--skip-tokenizer-init")
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
'level': 3,
|
"level": 3,
|
||||||
'custom_ops': ["+rms_norm"],
|
"custom_ops": ["+rms_norm"],
|
||||||
'compile_sizes': [4, 8],
|
"compile_sizes": [4, 8],
|
||||||
'pass_config': {
|
"pass_config": {
|
||||||
'enable_sequence_parallelism': True,
|
"enable_sequence_parallelism": True,
|
||||||
'enable_fusion': enable_fusion,
|
"enable_fusion": enable_fusion,
|
||||||
'enable_noop': True,
|
"enable_noop": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,12 +293,9 @@ def _compare_sp(
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(
|
||||||
tp_sp_args,
|
model_id, tp_sp_args, tp_args, tp_sp_env, tp_env, method=method
|
||||||
tp_args,
|
)
|
||||||
tp_sp_env,
|
|
||||||
tp_env,
|
|
||||||
method=method)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
testing_ray_compiled_graph = tp_sp_env is not None
|
testing_ray_compiled_graph = tp_sp_env is not None
|
||||||
if testing_ray_compiled_graph and vllm_major_version == "0":
|
if testing_ray_compiled_graph and vllm_major_version == "0":
|
||||||
@@ -301,10 +321,17 @@ SP_TEST_MODELS = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
|
(
|
||||||
"runner", "test_options"),
|
"model_id",
|
||||||
|
"parallel_setup",
|
||||||
|
"distributed_backend",
|
||||||
|
"vllm_major_version",
|
||||||
|
"runner",
|
||||||
|
"test_options",
|
||||||
|
),
|
||||||
[
|
[
|
||||||
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
|
params
|
||||||
|
for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
|
||||||
for params in settings.iter_params(model_id)
|
for params in settings.iter_params(model_id)
|
||||||
if model_id in SP_TEST_MODELS
|
if model_id in SP_TEST_MODELS
|
||||||
],
|
],
|
||||||
@@ -319,7 +346,8 @@ def test_tp_sp_generation(
|
|||||||
test_options: SPTestOptions,
|
test_options: SPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_sp(model_id,
|
_compare_sp(
|
||||||
|
model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
vllm_major_version,
|
vllm_major_version,
|
||||||
@@ -327,4 +355,5 @@ def test_tp_sp_generation(
|
|||||||
test_options,
|
test_options,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
method="generate",
|
method="generate",
|
||||||
is_multimodal=False)
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|||||||
@@ -26,13 +26,13 @@ def distributed_run(fn, world_size):
|
|||||||
processes = []
|
processes = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env = {}
|
env = {}
|
||||||
env['RANK'] = str(i)
|
env["RANK"] = str(i)
|
||||||
env['LOCAL_RANK'] = str(i)
|
env["LOCAL_RANK"] = str(i)
|
||||||
env['WORLD_SIZE'] = str(number_of_processes)
|
env["WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['MASTER_ADDR'] = 'localhost'
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env['MASTER_PORT'] = '12345'
|
env["MASTER_PORT"] = "12345"
|
||||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
p = multiprocessing.Process(target=fn, args=(env,))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
@@ -57,25 +57,23 @@ def worker_fn_wrapper(fn):
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
port = get_open_port()
|
port = get_open_port()
|
||||||
ip = '127.0.0.1'
|
ip = "127.0.0.1"
|
||||||
dist.broadcast_object_list([ip, port], src=0)
|
dist.broadcast_object_list([ip, port], src=0)
|
||||||
else:
|
else:
|
||||||
recv = [None, None]
|
recv = [None, None]
|
||||||
dist.broadcast_object_list(recv, src=0)
|
dist.broadcast_object_list(recv, src=0)
|
||||||
ip, port = recv # type: ignore
|
ip, port = recv # type: ignore
|
||||||
|
|
||||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
|
||||||
dist.get_world_size())
|
|
||||||
|
|
||||||
for pg in [dist.group.WORLD, stateless_pg]:
|
for pg in [dist.group.WORLD, stateless_pg]:
|
||||||
|
|
||||||
writer_rank = 2
|
writer_rank = 2
|
||||||
broadcaster = MessageQueue.create_from_process_group(
|
broadcaster = MessageQueue.create_from_process_group(
|
||||||
pg, 40 * 1024, 2, writer_rank)
|
pg, 40 * 1024, 2, writer_rank
|
||||||
|
)
|
||||||
if rank == writer_rank:
|
if rank == writer_rank:
|
||||||
seed = random.randint(0, 1000)
|
seed = random.randint(0, 1000)
|
||||||
dist.broadcast_object_list([seed], writer_rank)
|
dist.broadcast_object_list([seed], writer_rank)
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ import traceback
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||||
SingleWriterShmRingBuffer)
|
SingleWriterShmRingBuffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
||||||
@@ -25,18 +26,21 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
"""Test opening an existing buffer"""
|
"""Test opening an existing buffer"""
|
||||||
# First create a buffer
|
# First create a buffer
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=self.buffer_size, create=True)
|
data_buffer_size=self.buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
# Then open it with another instance
|
# Then open it with another instance
|
||||||
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
|
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
|
||||||
self.assertFalse(reader_buffer.is_writer)
|
self.assertFalse(reader_buffer.is_writer)
|
||||||
self.assertEqual(reader_buffer.shared_memory.name,
|
self.assertEqual(
|
||||||
self.ring_buffer.shared_memory.name)
|
reader_buffer.shared_memory.name, self.ring_buffer.shared_memory.name
|
||||||
|
)
|
||||||
|
|
||||||
def test_buffer_access(self):
|
def test_buffer_access(self):
|
||||||
"""Test accessing allocated buffers"""
|
"""Test accessing allocated buffers"""
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=self.buffer_size, create=True)
|
data_buffer_size=self.buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
size = 100
|
size = 100
|
||||||
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
||||||
@@ -44,11 +48,11 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
# Write some test data
|
# Write some test data
|
||||||
test_data = b"Hello, World!" * 7 # 91 bytes
|
test_data = b"Hello, World!" * 7 # 91 bytes
|
||||||
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
||||||
data_buf[0:len(test_data)] = test_data
|
data_buf[0 : len(test_data)] = test_data
|
||||||
|
|
||||||
# Read it back
|
# Read it back
|
||||||
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
|
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
|
||||||
read_data = bytes(data_buf2[0:len(test_data)])
|
read_data = bytes(data_buf2[0 : len(test_data)])
|
||||||
read_id = metadata2[0]
|
read_id = metadata2[0]
|
||||||
|
|
||||||
self.assertEqual(read_data, test_data)
|
self.assertEqual(read_data, test_data)
|
||||||
@@ -58,7 +62,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
"""Test that MemoryError is raised when buffer is full"""
|
"""Test that MemoryError is raised when buffer is full"""
|
||||||
small_buffer_size = 200
|
small_buffer_size = 200
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=small_buffer_size, create=True)
|
data_buffer_size=small_buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
# Fill up the buffer
|
# Fill up the buffer
|
||||||
self.ring_buffer.allocate_buf(100)
|
self.ring_buffer.allocate_buf(100)
|
||||||
@@ -72,7 +77,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
"""Test allocation and freeing of buffers"""
|
"""Test allocation and freeing of buffers"""
|
||||||
small_buffer_size = 200
|
small_buffer_size = 200
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=small_buffer_size, create=True)
|
data_buffer_size=small_buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
size = 80
|
size = 80
|
||||||
# Write some data
|
# Write some data
|
||||||
@@ -81,7 +87,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
||||||
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
||||||
data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use
|
data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use
|
||||||
data_buf[4:len(test_data) + 4] = test_data
|
data_buf[4 : len(test_data) + 4] = test_data
|
||||||
print(self.ring_buffer.metadata)
|
print(self.ring_buffer.metadata)
|
||||||
freed_ids = self.ring_buffer.free_buf(lambda *args: True)
|
freed_ids = self.ring_buffer.free_buf(lambda *args: True)
|
||||||
print(f" Freed IDs: {freed_ids}")
|
print(f" Freed IDs: {freed_ids}")
|
||||||
@@ -90,7 +96,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
def test_clear_buffer(self):
|
def test_clear_buffer(self):
|
||||||
"""Test clearing the buffer"""
|
"""Test clearing the buffer"""
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=self.buffer_size, create=True)
|
data_buffer_size=self.buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
# Allocate some buffers
|
# Allocate some buffers
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
@@ -121,8 +128,7 @@ def main():
|
|||||||
# Manual demonstration
|
# Manual demonstration
|
||||||
try:
|
try:
|
||||||
print("Creating ring buffer...")
|
print("Creating ring buffer...")
|
||||||
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048,
|
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, create=True)
|
||||||
create=True)
|
|
||||||
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())
|
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())
|
||||||
|
|
||||||
print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
|
print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
|
||||||
@@ -140,7 +146,7 @@ def main():
|
|||||||
# Write some test data
|
# Write some test data
|
||||||
with writer_buffer.access_buf(address) as (data_buf, metadata):
|
with writer_buffer.access_buf(address) as (data_buf, metadata):
|
||||||
test_message = f"Test message {i}".encode()
|
test_message = f"Test message {i}".encode()
|
||||||
data_buf[0:len(test_message)] = test_message
|
data_buf[0 : len(test_message)] = test_message
|
||||||
|
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(f" Failed to allocate {size} bytes: {e}")
|
print(f" Failed to allocate {size} bytes: {e}")
|
||||||
|
|||||||
@@ -12,28 +12,33 @@ import torch
|
|||||||
|
|
||||||
# Assuming these are imported from your module
|
# Assuming these are imported from your module
|
||||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||||
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
|
MsgpackSerde,
|
||||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
|
SingleWriterShmObjectStorage,
|
||||||
MultiModalSharedField)
|
SingleWriterShmRingBuffer,
|
||||||
|
)
|
||||||
|
from vllm.multimodal.inputs import (
|
||||||
|
MultiModalFieldElem,
|
||||||
|
MultiModalKwargsItem,
|
||||||
|
MultiModalSharedField,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _dummy_elem(modality: str, key: str, size: int):
|
def _dummy_elem(modality: str, key: str, size: int):
|
||||||
return MultiModalFieldElem(
|
return MultiModalFieldElem(
|
||||||
modality=modality,
|
modality=modality,
|
||||||
key=key,
|
key=key,
|
||||||
data=torch.empty((size, ), dtype=torch.int8),
|
data=torch.empty((size,), dtype=torch.int8),
|
||||||
field=MultiModalSharedField(1),
|
field=MultiModalSharedField(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||||
return MultiModalKwargsItem.from_elems([
|
return MultiModalKwargsItem.from_elems(
|
||||||
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
|
[_dummy_elem(modality, key, size) for key, size in size_by_key.items()]
|
||||||
])
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Set up test fixtures before each test method."""
|
"""Set up test fixtures before each test method."""
|
||||||
ring_buffer = SingleWriterShmRingBuffer(
|
ring_buffer = SingleWriterShmRingBuffer(
|
||||||
@@ -208,8 +213,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError) as context:
|
||||||
self.storage.get(address, monotonic_id + 100)
|
self.storage.get(address, monotonic_id + 100)
|
||||||
|
|
||||||
self.assertIn("has been modified or is invalid", \
|
self.assertIn("has been modified or is invalid", str(context.exception))
|
||||||
str(context.exception))
|
|
||||||
|
|
||||||
def test_clear_storage(self):
|
def test_clear_storage(self):
|
||||||
"""Test clearing the storage."""
|
"""Test clearing the storage."""
|
||||||
@@ -234,8 +238,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
|||||||
# Reader process function
|
# Reader process function
|
||||||
def reader_process(process_id, storage_handle, items_to_read):
|
def reader_process(process_id, storage_handle, items_to_read):
|
||||||
"""Reader process that connects to existing shared memory and reads data."""
|
"""Reader process that connects to existing shared memory and reads data."""
|
||||||
reader_storage = SingleWriterShmObjectStorage.create_from_handle(
|
reader_storage = SingleWriterShmObjectStorage.create_from_handle(storage_handle)
|
||||||
storage_handle)
|
|
||||||
|
|
||||||
print(f"Reader {process_id} started")
|
print(f"Reader {process_id} started")
|
||||||
|
|
||||||
@@ -276,11 +279,7 @@ def run_multiprocess_example():
|
|||||||
|
|
||||||
# Test basic data types
|
# Test basic data types
|
||||||
test_data = [
|
test_data = [
|
||||||
("user_data", {
|
("user_data", {"name": "Alice", "age": 30, "scores": [95, 87, 92]}),
|
||||||
"name": "Alice",
|
|
||||||
"age": 30,
|
|
||||||
"scores": [95, 87, 92]
|
|
||||||
}),
|
|
||||||
("simple_string", "Hello, World!"),
|
("simple_string", "Hello, World!"),
|
||||||
("number", 42),
|
("number", 42),
|
||||||
("list_data", [1, 2, 3, "four", 5.0]),
|
("list_data", [1, 2, 3, "four", 5.0]),
|
||||||
@@ -301,8 +300,9 @@ def run_multiprocess_example():
|
|||||||
# initialize lock for reader processes
|
# initialize lock for reader processes
|
||||||
handle.reader_lock = Lock()
|
handle.reader_lock = Lock()
|
||||||
for i in range(storage.n_readers):
|
for i in range(storage.n_readers):
|
||||||
p = multiprocessing.Process(target=reader_process,
|
p = multiprocessing.Process(
|
||||||
args=(i, handle, stored_items))
|
target=reader_process, args=(i, handle, stored_items)
|
||||||
|
)
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ import vllm.envs as envs
|
|||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||||
CudaCommunicator)
|
from vllm.distributed.parallel_state import (
|
||||||
from vllm.distributed.parallel_state import (get_tp_group,
|
get_tp_group,
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel)
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -32,8 +33,7 @@ test_size_elements = 1024 * 1024
|
|||||||
|
|
||||||
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
||||||
monkeypatch = pytest.MonkeyPatch()
|
monkeypatch = pytest.MonkeyPatch()
|
||||||
config = VllmConfig(parallel_config=ParallelConfig(
|
config = VllmConfig(parallel_config=ParallelConfig(tensor_parallel_size=world_size))
|
||||||
tensor_parallel_size=world_size))
|
|
||||||
|
|
||||||
with monkeypatch.context() as m, set_current_vllm_config(config):
|
with monkeypatch.context() as m, set_current_vllm_config(config):
|
||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
@@ -42,34 +42,34 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
|||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
cuda_communicator = typing.cast(CudaCommunicator,
|
cuda_communicator = typing.cast(
|
||||||
get_tp_group().device_communicator)
|
CudaCommunicator, get_tp_group().device_communicator
|
||||||
|
)
|
||||||
symm_mem_comm = cuda_communicator.symm_mem_comm
|
symm_mem_comm = cuda_communicator.symm_mem_comm
|
||||||
if symm_mem_comm is None or symm_mem_comm.disabled:
|
if symm_mem_comm is None or symm_mem_comm.disabled:
|
||||||
# can't use skip under multiprocessing
|
# can't use skip under multiprocessing
|
||||||
q.put("SymmMemCommunicator is not available or disabled.")
|
q.put("SymmMemCommunicator is not available or disabled.")
|
||||||
return
|
return
|
||||||
|
|
||||||
inp_direct_symm_mem = torch.randint(1,
|
inp_direct_symm_mem = torch.randint(
|
||||||
23, (test_size_elements, ),
|
1, 23, (test_size_elements,), dtype=dtype, device=device
|
||||||
dtype=dtype,
|
)
|
||||||
device=device)
|
|
||||||
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
|
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
|
||||||
# can't use skip under multiprocessing
|
# can't use skip under multiprocessing
|
||||||
q.put(
|
q.put("SymmMemCommunicator isn't used for this world and input size.")
|
||||||
"SymmMemCommunicator isn't used for this world and input size."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
|
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
|
||||||
@@ -78,42 +78,37 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
|||||||
|
|
||||||
group = get_tp_group().device_group
|
group = get_tp_group().device_group
|
||||||
dist.all_reduce(original_inp_direct_symm_mem, group=group)
|
dist.all_reduce(original_inp_direct_symm_mem, group=group)
|
||||||
torch.testing.assert_close(out_direct_symm_mem,
|
torch.testing.assert_close(
|
||||||
original_inp_direct_symm_mem,
|
out_direct_symm_mem, original_inp_direct_symm_mem, atol=2.5, rtol=0.1
|
||||||
atol=2.5,
|
)
|
||||||
rtol=0.1)
|
|
||||||
|
|
||||||
# Test tensor_model_parallel_all_reduce which should use symm_mem
|
# Test tensor_model_parallel_all_reduce which should use symm_mem
|
||||||
inp_tensor_parallel = torch.randint(-23,
|
inp_tensor_parallel = torch.randint(
|
||||||
1, (test_size_elements, ),
|
-23, 1, (test_size_elements,), dtype=dtype, device=device
|
||||||
dtype=dtype,
|
)
|
||||||
device=device)
|
|
||||||
original_inp_tensor_parallel = inp_tensor_parallel.clone()
|
original_inp_tensor_parallel = inp_tensor_parallel.clone()
|
||||||
out_tensor_parallel = tensor_model_parallel_all_reduce(
|
out_tensor_parallel = tensor_model_parallel_all_reduce(inp_tensor_parallel)
|
||||||
inp_tensor_parallel)
|
|
||||||
dist.all_reduce(original_inp_tensor_parallel, group=group)
|
dist.all_reduce(original_inp_tensor_parallel, group=group)
|
||||||
torch.testing.assert_close(out_tensor_parallel,
|
torch.testing.assert_close(
|
||||||
original_inp_tensor_parallel,
|
out_tensor_parallel, original_inp_tensor_parallel, atol=2.5, rtol=0.1
|
||||||
atol=2.5,
|
)
|
||||||
rtol=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda(),
|
||||||
reason="SymmMemAllreduce is only available for CUDA platforms.")
|
reason="SymmMemAllreduce is only available for CUDA platforms.",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("pipeline_parallel_size", [1])
|
@pytest.mark.parametrize("pipeline_parallel_size", [1])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
def test_symm_mem_allreduce(
|
||||||
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size
|
||||||
pipeline_parallel_size):
|
):
|
||||||
world_size = tp_size * pipeline_parallel_size
|
world_size = tp_size * pipeline_parallel_size
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
q = mp.get_context('spawn').Queue()
|
q = mp.get_context("spawn").Queue()
|
||||||
mp.spawn(symm_mem_allreduce_worker,
|
mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size)
|
||||||
args=(world_size, q),
|
|
||||||
nprocs=world_size)
|
|
||||||
try:
|
try:
|
||||||
val = q.get(timeout=1)
|
val = q.get(timeout=1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
@@ -126,18 +121,20 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
|||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda(),
|
||||||
reason="SymmMemAllreduce is only available for CUDA platforms.")
|
reason="SymmMemAllreduce is only available for CUDA platforms.",
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
)
|
||||||
reason="Only test on CUDA")
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
|
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
|
||||||
world_size = 4
|
world_size = 4
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
# Verify that the DataParallel runs without error
|
# Verify that the DataParallel runs without error
|
||||||
engine_args = EngineArgs(model="distilbert/distilgpt2",
|
engine_args = EngineArgs(
|
||||||
|
model="distilbert/distilgpt2",
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
enable_prefix_caching=True,
|
enable_prefix_caching=True,
|
||||||
data_parallel_size=2,
|
data_parallel_size=2,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
data_parallel_backend="mp")
|
data_parallel_backend="mp",
|
||||||
|
)
|
||||||
LLMEngine.from_engine_args(engine_args)
|
LLMEngine.from_engine_args(engine_args)
|
||||||
|
|||||||
@@ -24,13 +24,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
|||||||
|
|
||||||
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
||||||
# to test if all ranks agree on the same kv cache configuration.
|
# to test if all ranks agree on the same kv cache configuration.
|
||||||
llm = LLM(model="facebook/opt-125m",
|
llm = LLM(
|
||||||
|
model="facebook/opt-125m",
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
|
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
|
||||||
distributed_executor_backend="external_launcher",
|
distributed_executor_backend="external_launcher",
|
||||||
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||||
swap_space=random.randint(1, 4),
|
swap_space=random.randint(1, 4),
|
||||||
seed=0)
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
@@ -48,15 +50,14 @@ def test_consistent_across_ranks(obj):
|
|||||||
assert container[0] == obj
|
assert container[0] == obj
|
||||||
|
|
||||||
|
|
||||||
test_consistent_across_ranks(
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
||||||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
||||||
test_consistent_across_ranks(
|
|
||||||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
|
||||||
|
|
||||||
# make sure we can access the model parameters from the calling process
|
# make sure we can access the model parameters from the calling process
|
||||||
# of the `LLM` instance.
|
# of the `LLM` instance.
|
||||||
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
|
params = list(
|
||||||
model.parameters())
|
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
|
||||||
|
)
|
||||||
test_consistent_across_ranks(len(params))
|
test_consistent_across_ranks(len(params))
|
||||||
|
|
||||||
# all ranks should have the same outputs
|
# all ranks should have the same outputs
|
||||||
@@ -65,5 +66,4 @@ for output in outputs:
|
|||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
test_consistent_across_ranks(prompt)
|
test_consistent_across_ranks(prompt)
|
||||||
test_consistent_across_ranks(generated_text)
|
test_consistent_across_ranks(generated_text)
|
||||||
print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
|
print(f"Rank {torch_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
f"Generated text: {generated_text!r}")
|
|
||||||
|
|||||||
@@ -24,23 +24,22 @@ dp_rank = int(os.getenv("DP_RANK", "0"))
|
|||||||
|
|
||||||
if dp_size > 1:
|
if dp_size > 1:
|
||||||
# distribute the prompts across the data parallel ranks
|
# distribute the prompts across the data parallel ranks
|
||||||
prompts = [
|
prompts = [prompt for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank]
|
||||||
prompt for idx, prompt in enumerate(prompts)
|
|
||||||
if idx % dp_size == dp_rank
|
|
||||||
]
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
||||||
# to test if all ranks agree on the same kv cache configuration.
|
# to test if all ranks agree on the same kv cache configuration.
|
||||||
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
|
llm = LLM(
|
||||||
|
model="microsoft/Phi-mini-MoE-instruct",
|
||||||
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
|
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
|
||||||
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
|
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
|
||||||
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
|
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
|
||||||
distributed_executor_backend="external_launcher",
|
distributed_executor_backend="external_launcher",
|
||||||
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||||
swap_space=random.randint(1, 4),
|
swap_space=random.randint(1, 4),
|
||||||
seed=0)
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
@@ -54,21 +53,18 @@ def test_consistent_across_ranks(obj):
|
|||||||
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
|
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
|
||||||
else:
|
else:
|
||||||
container = [None]
|
container = [None]
|
||||||
dist.broadcast_object_list(container,
|
dist.broadcast_object_list(container, src=group.ranks[0], group=cpu_group)
|
||||||
src=group.ranks[0],
|
|
||||||
group=cpu_group)
|
|
||||||
assert container[0] == obj
|
assert container[0] == obj
|
||||||
|
|
||||||
|
|
||||||
test_consistent_across_ranks(
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
||||||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
||||||
test_consistent_across_ranks(
|
|
||||||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
|
||||||
|
|
||||||
# make sure we can access the model parameters from the calling process
|
# make sure we can access the model parameters from the calling process
|
||||||
# of the `LLM` instance.
|
# of the `LLM` instance.
|
||||||
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
|
params = list(
|
||||||
model.parameters())
|
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
|
||||||
|
)
|
||||||
test_consistent_across_ranks(len(params))
|
test_consistent_across_ranks(len(params))
|
||||||
|
|
||||||
# all ranks should have the same outputs
|
# all ranks should have the same outputs
|
||||||
@@ -77,5 +73,4 @@ for output in outputs:
|
|||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
test_consistent_across_ranks(prompt)
|
test_consistent_across_ranks(prompt)
|
||||||
test_consistent_across_ranks(generated_text)
|
test_consistent_across_ranks(generated_text)
|
||||||
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
|
print(f"Rank {group_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
f"Generated text: {generated_text!r}")
|
|
||||||
|
|||||||
@@ -10,21 +10,22 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
from vllm.distributed.utils import StatelessProcessGroup
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.utils import (cuda_device_count_stateless, get_open_port,
|
from vllm.utils import (
|
||||||
update_environment_variables)
|
cuda_device_count_stateless,
|
||||||
|
get_open_port,
|
||||||
|
update_environment_variables,
|
||||||
|
)
|
||||||
|
|
||||||
from ..utils import multi_gpu_test
|
from ..utils import multi_gpu_test
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
class _CUDADeviceCountStatelessTestActor:
|
class _CUDADeviceCountStatelessTestActor:
|
||||||
|
|
||||||
def get_count(self):
|
def get_count(self):
|
||||||
return cuda_device_count_stateless()
|
return cuda_device_count_stateless()
|
||||||
|
|
||||||
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
||||||
update_environment_variables(
|
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
|
||||||
|
|
||||||
def get_cuda_visible_devices(self):
|
def get_cuda_visible_devices(self):
|
||||||
return envs.CUDA_VISIBLE_DEVICES
|
return envs.CUDA_VISIBLE_DEVICES
|
||||||
@@ -34,10 +35,9 @@ def test_cuda_device_count_stateless():
|
|||||||
"""Test that cuda_device_count_stateless changes return value if
|
"""Test that cuda_device_count_stateless changes return value if
|
||||||
CUDA_VISIBLE_DEVICES is changed."""
|
CUDA_VISIBLE_DEVICES is changed."""
|
||||||
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
|
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
|
||||||
num_gpus=2).remote()
|
num_gpus=2
|
||||||
assert len(
|
).remote()
|
||||||
sorted(ray.get(
|
assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2
|
||||||
actor.get_cuda_visible_devices.remote()).split(","))) == 2
|
|
||||||
assert ray.get(actor.get_count.remote()) == 2
|
assert ray.get(actor.get_count.remote()) == 2
|
||||||
ray.get(actor.set_cuda_visible_devices.remote("0"))
|
ray.get(actor.set_cuda_visible_devices.remote("0"))
|
||||||
assert ray.get(actor.get_count.remote()) == 1
|
assert ray.get(actor.get_count.remote()) == 1
|
||||||
@@ -46,15 +46,13 @@ def test_cuda_device_count_stateless():
|
|||||||
|
|
||||||
|
|
||||||
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg2 = StatelessProcessGroup.create(
|
||||||
port=port2,
|
host="127.0.0.1", port=port2, rank=rank, world_size=3
|
||||||
rank=rank,
|
)
|
||||||
world_size=3)
|
|
||||||
data = torch.tensor([rank])
|
data = torch.tensor([rank])
|
||||||
data = pg1.broadcast_obj(data, src=2)
|
data = pg1.broadcast_obj(data, src=2)
|
||||||
assert data.item() == 2
|
assert data.item() == 2
|
||||||
@@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg2 = StatelessProcessGroup.create(
|
||||||
port=port2,
|
host="127.0.0.1", port=port2, rank=rank, world_size=3
|
||||||
rank=rank,
|
)
|
||||||
world_size=3)
|
|
||||||
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||||
data = torch.tensor([rank]).cuda()
|
data = torch.tensor([rank]).cuda()
|
||||||
pynccl1.all_reduce(data)
|
pynccl1.all_reduce(data)
|
||||||
@@ -96,10 +92,9 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
|
|
||||||
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
if rank == 2:
|
if rank == 2:
|
||||||
pg1.broadcast_obj("secret", src=2)
|
pg1.broadcast_obj("secret", src=2)
|
||||||
else:
|
else:
|
||||||
@@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
|
|
||||||
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
data = pg1.all_gather_obj(rank)
|
data = pg1.all_gather_obj(rank)
|
||||||
assert data == list(range(WORLD_SIZE))
|
assert data == list(range(WORLD_SIZE))
|
||||||
pg1.barrier()
|
pg1.barrier()
|
||||||
@@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
@pytest.mark.skip(reason="This test is flaky and prone to hang.")
|
@pytest.mark.skip(reason="This test is flaky and prone to hang.")
|
||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]
|
||||||
|
)
|
||||||
def test_stateless_process_group(worker):
|
def test_stateless_process_group(worker):
|
||||||
port1 = get_open_port()
|
port1 = get_open_port()
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
@@ -129,12 +124,14 @@ def test_stateless_process_group(worker):
|
|||||||
port2 = get_open_port()
|
port2 = get_open_port()
|
||||||
WORLD_SIZE = 4
|
WORLD_SIZE = 4
|
||||||
from multiprocessing import get_context
|
from multiprocessing import get_context
|
||||||
|
|
||||||
ctx = get_context("fork")
|
ctx = get_context("fork")
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(WORLD_SIZE):
|
for i in range(WORLD_SIZE):
|
||||||
rank = i
|
rank = i
|
||||||
processes.append(
|
processes.append(
|
||||||
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)))
|
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
|
||||||
|
)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.start()
|
p.start()
|
||||||
for p in processes:
|
for p in processes:
|
||||||
|
|||||||
@@ -10,22 +10,30 @@ from typing import Annotated, Literal, Optional, Union
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import CompilationConfig, config
|
from vllm.config import CompilationConfig, config
|
||||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
from vllm.engine.arg_utils import (
|
||||||
get_type, get_type_hints, is_not_builtin,
|
EngineArgs,
|
||||||
is_type, literal_to_kwargs, optional_type,
|
contains_type,
|
||||||
parse_type)
|
get_kwargs,
|
||||||
|
get_type,
|
||||||
|
get_type_hints,
|
||||||
|
is_not_builtin,
|
||||||
|
is_type,
|
||||||
|
literal_to_kwargs,
|
||||||
|
optional_type,
|
||||||
|
parse_type,
|
||||||
|
)
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type", "value", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
|
("type", "value", "expected"),
|
||||||
|
[
|
||||||
(int, "42", 42),
|
(int, "42", 42),
|
||||||
(float, "3.14", 3.14),
|
(float, "3.14", 3.14),
|
||||||
(str, "Hello World!", "Hello World!"),
|
(str, "Hello World!", "Hello World!"),
|
||||||
(json.loads, '{"foo":1,"bar":2}', {
|
(json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}),
|
||||||
"foo": 1,
|
],
|
||||||
"bar": 2
|
)
|
||||||
}),
|
|
||||||
])
|
|
||||||
def test_parse_type(type, value, expected):
|
def test_parse_type(type, value, expected):
|
||||||
parse_type_func = parse_type(type)
|
parse_type_func = parse_type(type)
|
||||||
assert parse_type_func(value) == expected
|
assert parse_type_func(value) == expected
|
||||||
@@ -37,18 +45,23 @@ def test_optional_type():
|
|||||||
assert optional_type_func("42") == 42
|
assert optional_type_func("42") == 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
|
("type_hint", "type", "expected"),
|
||||||
|
[
|
||||||
(int, int, True),
|
(int, int, True),
|
||||||
(int, float, False),
|
(int, float, False),
|
||||||
(list[int], list, True),
|
(list[int], list, True),
|
||||||
(list[int], tuple, False),
|
(list[int], tuple, False),
|
||||||
(Literal[0, 1], Literal, True),
|
(Literal[0, 1], Literal, True),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_is_type(type_hint, type, expected):
|
def test_is_type(type_hint, type, expected):
|
||||||
assert is_type(type_hint, type) == expected
|
assert is_type(type_hint, type) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
|
("type_hints", "type", "expected"),
|
||||||
|
[
|
||||||
({float, int}, int, True),
|
({float, int}, int, True),
|
||||||
({int, tuple}, int, True),
|
({int, tuple}, int, True),
|
||||||
({int, tuple[int]}, int, True),
|
({int, tuple[int]}, int, True),
|
||||||
@@ -56,31 +69,32 @@ def test_is_type(type_hint, type, expected):
|
|||||||
({int, tuple[int]}, float, False),
|
({int, tuple[int]}, float, False),
|
||||||
({int, tuple[int, ...]}, float, False),
|
({int, tuple[int, ...]}, float, False),
|
||||||
({str, Literal["x", "y"]}, Literal, True),
|
({str, Literal["x", "y"]}, Literal, True),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_contains_type(type_hints, type, expected):
|
def test_contains_type(type_hints, type, expected):
|
||||||
assert contains_type(type_hints, type) == expected
|
assert contains_type(type_hints, type) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
|
("type_hints", "type", "expected"),
|
||||||
|
[
|
||||||
({int, float}, int, int),
|
({int, float}, int, int),
|
||||||
({int, float}, str, None),
|
({int, float}, str, None),
|
||||||
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_get_type(type_hints, type, expected):
|
def test_get_type(type_hints, type, expected):
|
||||||
assert get_type(type_hints, type) == expected
|
assert get_type(type_hints, type) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hints", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
({Literal[1, 2]}, {
|
("type_hints", "expected"),
|
||||||
"type": int,
|
[
|
||||||
"choices": [1, 2]
|
({Literal[1, 2]}, {"type": int, "choices": [1, 2]}),
|
||||||
}),
|
({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}),
|
||||||
({str, Literal["x", "y"]}, {
|
|
||||||
"type": str,
|
|
||||||
"metavar": ["x", "y"]
|
|
||||||
}),
|
|
||||||
({Literal[1, "a"]}, Exception),
|
({Literal[1, "a"]}, Exception),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_literal_to_kwargs(type_hints, expected):
|
def test_literal_to_kwargs(type_hints, expected):
|
||||||
context = nullcontext()
|
context = nullcontext()
|
||||||
if expected is Exception:
|
if expected is Exception:
|
||||||
@@ -123,22 +137,27 @@ class DummyConfig:
|
|||||||
"""Nested config"""
|
"""Nested config"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
|
("type_hint", "expected"),
|
||||||
|
[
|
||||||
(int, False),
|
(int, False),
|
||||||
(DummyConfig, True),
|
(DummyConfig, True),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_is_not_builtin(type_hint, expected):
|
def test_is_not_builtin(type_hint, expected):
|
||||||
assert is_not_builtin(type_hint) == expected
|
assert is_not_builtin(type_hint) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("type_hint", "expected"), [
|
("type_hint", "expected"),
|
||||||
|
[
|
||||||
(Annotated[int, "annotation"], {int}),
|
(Annotated[int, "annotation"], {int}),
|
||||||
(Optional[int], {int, type(None)}),
|
(Optional[int], {int, type(None)}),
|
||||||
(Annotated[Optional[int], "annotation"], {int, type(None)}),
|
(Annotated[Optional[int], "annotation"], {int, type(None)}),
|
||||||
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
|
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
|
||||||
],
|
],
|
||||||
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
|
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"],
|
||||||
|
)
|
||||||
def test_get_type_hints(type_hint, expected):
|
def test_get_type_hints(type_hint, expected):
|
||||||
assert get_type_hints(type_hint) == expected
|
assert get_type_hints(type_hint) == expected
|
||||||
|
|
||||||
@@ -178,24 +197,16 @@ def test_get_kwargs():
|
|||||||
("arg", "expected"),
|
("arg", "expected"),
|
||||||
[
|
[
|
||||||
(None, dict()),
|
(None, dict()),
|
||||||
('{"video": {"num_frames": 123} }', {
|
('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}),
|
||||||
"video": {
|
|
||||||
"num_frames": 123
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
(
|
(
|
||||||
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
|
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
|
||||||
{
|
{
|
||||||
"video": {
|
"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"},
|
||||||
"num_frames": 123,
|
"image": {"foo": "bar"},
|
||||||
"fps": 1.0,
|
|
||||||
"foo": "bar"
|
|
||||||
},
|
},
|
||||||
"image": {
|
),
|
||||||
"foo": "bar"
|
],
|
||||||
}
|
)
|
||||||
}),
|
|
||||||
])
|
|
||||||
def test_media_io_kwargs_parser(arg, expected):
|
def test_media_io_kwargs_parser(arg, expected):
|
||||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||||
if arg is None:
|
if arg is None:
|
||||||
@@ -230,24 +241,32 @@ def test_compilation_config():
|
|||||||
assert args.compilation_config.level == 3
|
assert args.compilation_config.level == 3
|
||||||
|
|
||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
"-O",
|
"-O",
|
||||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||||
'"use_inductor": false}',
|
'"use_inductor": false}',
|
||||||
])
|
]
|
||||||
assert (args.compilation_config.level == 3 and
|
)
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
assert (
|
||||||
and not args.compilation_config.use_inductor)
|
args.compilation_config.level == 3
|
||||||
|
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
|
and not args.compilation_config.use_inductor
|
||||||
|
)
|
||||||
|
|
||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
"--compilation-config="
|
"--compilation-config="
|
||||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||||
'"use_inductor": true}',
|
'"use_inductor": true}',
|
||||||
])
|
]
|
||||||
assert (args.compilation_config.level == 3 and
|
)
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
assert (
|
||||||
and args.compilation_config.use_inductor)
|
args.compilation_config.level == 3
|
||||||
|
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
|
and args.compilation_config.use_inductor
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_default():
|
def test_prefix_cache_default():
|
||||||
@@ -255,8 +274,7 @@ def test_prefix_cache_default():
|
|||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
|
|
||||||
engine_args = EngineArgs.from_cli_args(args=args)
|
engine_args = EngineArgs.from_cli_args(args=args)
|
||||||
assert (not engine_args.enable_prefix_caching
|
assert not engine_args.enable_prefix_caching, "prefix caching defaults to off."
|
||||||
), "prefix caching defaults to off."
|
|
||||||
|
|
||||||
# with flag to turn it on.
|
# with flag to turn it on.
|
||||||
args = parser.parse_args(["--enable-prefix-caching"])
|
args = parser.parse_args(["--enable-prefix-caching"])
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import pytest
|
|||||||
|
|
||||||
from ..conftest import IMAGE_ASSETS
|
from ..conftest import IMAGE_ASSETS
|
||||||
|
|
||||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
|
||||||
"stop_sign":
|
{
|
||||||
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
"stop_sign": "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
||||||
"cherry_blossom":
|
"cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:",
|
||||||
"USER: <image>\nWhat is the season?\nASSISTANT:",
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
models = ["llava-hf/llava-1.5-7b-hf"]
|
models = ["llava-hf/llava-1.5-7b-hf"]
|
||||||
|
|
||||||
@@ -19,8 +19,7 @@ models = ["llava-hf/llava-1.5-7b-hf"]
|
|||||||
def test_context_length_too_short(vllm_runner, image_assets, model):
|
def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||||
images = [asset.pil_image for asset in image_assets]
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
with pytest.raises(ValueError,
|
with pytest.raises(ValueError, match="longer than the maximum model length"):
|
||||||
match="longer than the maximum model length"):
|
|
||||||
vllm_model = vllm_runner(
|
vllm_model = vllm_runner(
|
||||||
model,
|
model,
|
||||||
max_model_len=128, # LLaVA has a feature size of 576
|
max_model_len=128, # LLaVA has a feature size of 576
|
||||||
@@ -29,6 +28,6 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with vllm_model:
|
with vllm_model:
|
||||||
vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]],
|
vllm_model.generate_greedy(
|
||||||
max_tokens=1,
|
[HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]]
|
||||||
images=[images[0]])
|
)
|
||||||
|
|||||||
@@ -26,8 +26,10 @@ def sample_token_ids():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_regex():
|
def sample_regex():
|
||||||
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
return (
|
||||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||||
|
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -35,40 +37,27 @@ def sample_json_schema():
|
|||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"name": {
|
"name": {"type": "string"},
|
||||||
"type": "string"
|
"age": {"type": "integer"},
|
||||||
},
|
|
||||||
"age": {
|
|
||||||
"type": "integer"
|
|
||||||
},
|
|
||||||
"skills": {
|
"skills": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {"type": "string", "maxLength": 10},
|
||||||
"type": "string",
|
"minItems": 3,
|
||||||
"maxLength": 10
|
|
||||||
},
|
|
||||||
"minItems": 3
|
|
||||||
},
|
},
|
||||||
"work_history": {
|
"work_history": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"company": {
|
"company": {"type": "string"},
|
||||||
"type": "string"
|
"duration": {"type": "number"},
|
||||||
|
"position": {"type": "string"},
|
||||||
},
|
},
|
||||||
"duration": {
|
"required": ["company", "position"],
|
||||||
"type": "number"
|
|
||||||
},
|
},
|
||||||
"position": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["company", "position"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["name", "age", "skills", "work_history"]
|
"required": ["name", "age", "skills", "work_history"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -80,65 +69,53 @@ def sample_complex_json_schema():
|
|||||||
"score": {
|
"score": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"minimum": 0,
|
"minimum": 0,
|
||||||
"maximum": 100 # Numeric range
|
"maximum": 100, # Numeric range
|
||||||
},
|
},
|
||||||
"grade": {
|
"grade": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"pattern": "^[A-D]$" # Regex pattern
|
"pattern": "^[A-D]$", # Regex pattern
|
||||||
},
|
},
|
||||||
"email": {
|
"email": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$",
|
||||||
},
|
},
|
||||||
"tags": {
|
"tags": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"pattern":
|
"pattern": "^[a-z]{1,10}$", # Combining length and pattern restrictions
|
||||||
"^[a-z]{1,10}$" # Combining length and pattern restrictions
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["score", "grade", "email", "tags"]
|
},
|
||||||
|
},
|
||||||
|
"required": ["score", "grade", "email", "tags"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_definition_json_schema():
|
def sample_definition_json_schema():
|
||||||
return {
|
return {
|
||||||
'$defs': {
|
"$defs": {
|
||||||
'Step': {
|
"Step": {
|
||||||
'properties': {
|
"properties": {
|
||||||
'explanation': {
|
"explanation": {"title": "Explanation", "type": "string"},
|
||||||
'title': 'Explanation',
|
"output": {"title": "Output", "type": "string"},
|
||||||
'type': 'string'
|
|
||||||
},
|
},
|
||||||
'output': {
|
"required": ["explanation", "output"],
|
||||||
'title': 'Output',
|
"title": "Step",
|
||||||
'type': 'string'
|
"type": "object",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'required': ['explanation', 'output'],
|
"properties": {
|
||||||
'title': 'Step',
|
"steps": {
|
||||||
'type': 'object'
|
"items": {"$ref": "#/$defs/Step"},
|
||||||
}
|
"title": "Steps",
|
||||||
|
"type": "array",
|
||||||
},
|
},
|
||||||
'properties': {
|
"final_answer": {"title": "Final Answer", "type": "string"},
|
||||||
'steps': {
|
|
||||||
'items': {
|
|
||||||
'$ref': '#/$defs/Step'
|
|
||||||
},
|
},
|
||||||
'title': 'Steps',
|
"required": ["steps", "final_answer"],
|
||||||
'type': 'array'
|
"title": "MathReasoning",
|
||||||
},
|
"type": "object",
|
||||||
'final_answer': {
|
|
||||||
'title': 'Final Answer',
|
|
||||||
'type': 'string'
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'required': ['steps', 'final_answer'],
|
|
||||||
'title': 'MathReasoning',
|
|
||||||
'type': 'object'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -149,64 +126,71 @@ def sample_enum_json_schema():
|
|||||||
"properties": {
|
"properties": {
|
||||||
"status": {
|
"status": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["active", "inactive",
|
"enum": ["active", "inactive", "pending"], # Literal values using enum
|
||||||
"pending"] # Literal values using enum
|
|
||||||
},
|
},
|
||||||
"priority": {
|
"priority": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["low", "medium", "high", "critical"]
|
"enum": ["low", "medium", "high", "critical"],
|
||||||
},
|
},
|
||||||
"category": {
|
"category": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"type": {
|
"type": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["bug", "feature", "improvement"]
|
"enum": ["bug", "feature", "improvement"],
|
||||||
},
|
},
|
||||||
"severity": {
|
"severity": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"enum": [1, 2, 3, 4,
|
"enum": [1, 2, 3, 4, 5], # Enum can also contain numbers
|
||||||
5] # Enum can also contain numbers
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["type", "severity"]
|
},
|
||||||
|
"required": ["type", "severity"],
|
||||||
},
|
},
|
||||||
"flags": {
|
"flags": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["urgent", "blocked", "needs_review", "approved"]
|
"enum": ["urgent", "blocked", "needs_review", "approved"],
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["status", "priority", "category", "flags"]
|
},
|
||||||
|
},
|
||||||
|
"required": ["status", "priority", "category", "flags"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_structured_outputs_choices():
|
def sample_structured_outputs_choices():
|
||||||
return [
|
return [
|
||||||
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
|
"Python",
|
||||||
"Ruby", "Swift", "Kotlin"
|
"Java",
|
||||||
|
"JavaScript",
|
||||||
|
"C++",
|
||||||
|
"C#",
|
||||||
|
"PHP",
|
||||||
|
"TypeScript",
|
||||||
|
"Ruby",
|
||||||
|
"Swift",
|
||||||
|
"Kotlin",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_sql_statements():
|
def sample_sql_statements():
|
||||||
return ("""
|
return """
|
||||||
start: select_statement
|
start: select_statement
|
||||||
select_statement: "SELECT" column "from" table "where" condition
|
select_statement: "SELECT" column "from" table "where" condition
|
||||||
column: "col_1" | "col_2"
|
column: "col_1" | "col_2"
|
||||||
table: "table_1" | "table_2"
|
table: "table_1" | "table_2"
|
||||||
condition: column "=" number
|
condition: column "=" number
|
||||||
number: "1" | "2"
|
number: "1" | "2"
|
||||||
""")
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def zephyr_lora_files():
|
def zephyr_lora_files():
|
||||||
"""Download zephyr LoRA files once per test session."""
|
"""Download zephyr LoRA files once per test session."""
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
|
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
|
||||||
|
|
||||||
|
|
||||||
@@ -214,5 +198,5 @@ def zephyr_lora_files():
|
|||||||
def opt125_lora_files() -> str:
|
def opt125_lora_files() -> str:
|
||||||
"""Download opt-125m LoRA files once per test session."""
|
"""Download opt-125m LoRA files once per test session."""
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
return snapshot_download(
|
|
||||||
repo_id="peft-internal-testing/opt-125m-dummy-lora")
|
return snapshot_download(repo_id="peft-internal-testing/opt-125m-dummy-lora")
|
||||||
|
|||||||
@@ -48,20 +48,23 @@ def run_test(model_name, more_args=None):
|
|||||||
|
|
||||||
measured_value = results["results"][TASK][FILTER]
|
measured_value = results["results"][TASK][FILTER]
|
||||||
assert model_name in EXPECTED_VALUES, (
|
assert model_name in EXPECTED_VALUES, (
|
||||||
f"Cannot find the expected value for the model {model_name=}")
|
f"Cannot find the expected value for the model {model_name=}"
|
||||||
|
)
|
||||||
expected_value = EXPECTED_VALUES[model_name]
|
expected_value = EXPECTED_VALUES[model_name]
|
||||||
assert (measured_value - RTOL < expected_value
|
assert (
|
||||||
|
measured_value - RTOL < expected_value
|
||||||
and measured_value + RTOL > expected_value
|
and measured_value + RTOL > expected_value
|
||||||
), f"Expected: {expected_value} | Measured: {measured_value}"
|
), f"Expected: {expected_value} | Measured: {measured_value}"
|
||||||
|
|
||||||
|
|
||||||
# TODO: [AlexM] Fix it with new CI/CD tests
|
# TODO: [AlexM] Fix it with new CI/CD tests
|
||||||
TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
|
TPU_TP_TEST_STR = "" # "tensor_parallel_size=4"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
and not current_platform.is_tpu(),
|
not current_platform.is_cuda() and not current_platform.is_tpu(),
|
||||||
reason="V1 is currently only supported on CUDA and TPU")
|
reason="V1 is currently only supported on CUDA and TPU",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("model", MODEL_NAMES)
|
@pytest.mark.parametrize("model", MODEL_NAMES)
|
||||||
def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Run with the V1 Engine."""
|
"""Run with the V1 Engine."""
|
||||||
@@ -82,12 +85,14 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
|||||||
run_test(model, more_args)
|
run_test(model, more_args)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
and not current_platform.is_tpu(),
|
not current_platform.is_cuda() and not current_platform.is_tpu(),
|
||||||
reason="V1 is currently only supported on CUDA and TPU")
|
reason="V1 is currently only supported on CUDA and TPU",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
|
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
|
||||||
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
|
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
|
||||||
model, monkeypatch: pytest.MonkeyPatch):
|
model, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
"""Run with the V1 Engine."""
|
"""Run with the V1 Engine."""
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
|
|||||||
@@ -14,9 +14,7 @@ from ..openai.test_vision import TEST_IMAGE_ASSETS
|
|||||||
def text_llm():
|
def text_llm():
|
||||||
# pytest caches the fixture so we use weakref.proxy to
|
# pytest caches the fixture so we use weakref.proxy to
|
||||||
# enable garbage collection
|
# enable garbage collection
|
||||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, seed=0)
|
||||||
enforce_eager=True,
|
|
||||||
seed=0)
|
|
||||||
|
|
||||||
yield weakref.proxy(llm)
|
yield weakref.proxy(llm)
|
||||||
|
|
||||||
@@ -28,14 +26,8 @@ def text_llm():
|
|||||||
def test_chat(text_llm):
|
def test_chat(text_llm):
|
||||||
prompt1 = "Explain the concept of entropy."
|
prompt1 = "Explain the concept of entropy."
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": prompt1},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt1
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
outputs = text_llm.chat(messages)
|
outputs = text_llm.chat(messages)
|
||||||
assert len(outputs) == 1
|
assert len(outputs) == 1
|
||||||
@@ -46,25 +38,13 @@ def test_multi_chat(text_llm):
|
|||||||
prompt2 = "Explain what among us is."
|
prompt2 = "Explain what among us is."
|
||||||
|
|
||||||
conversation1 = [
|
conversation1 = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": prompt1},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt1
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
conversation2 = [
|
conversation2 = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": prompt2},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt2
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
messages = [conversation1, conversation2]
|
messages = [conversation1, conversation2]
|
||||||
@@ -94,26 +74,22 @@ def vision_llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("image_urls",
|
@pytest.mark.parametrize(
|
||||||
[[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]],
|
"image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True
|
||||||
indirect=True)
|
)
|
||||||
def test_chat_multi_image(vision_llm, image_urls: list[str]):
|
def test_chat_multi_image(vision_llm, image_urls: list[str]):
|
||||||
messages = [{
|
messages = [
|
||||||
"role":
|
|
||||||
"user",
|
|
||||||
"content": [
|
|
||||||
*({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": image_url
|
|
||||||
}
|
|
||||||
} for image_url in image_urls),
|
|
||||||
{
|
{
|
||||||
"type": "text",
|
"role": "user",
|
||||||
"text": "What's in this image?"
|
"content": [
|
||||||
},
|
*(
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
|
for image_url in image_urls
|
||||||
|
),
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
outputs = vision_llm.chat(messages)
|
outputs = vision_llm.chat(messages)
|
||||||
assert len(outputs) >= 0
|
assert len(outputs) >= 0
|
||||||
|
|
||||||
@@ -124,14 +100,8 @@ def test_llm_chat_tokenization_no_double_bos(text_llm):
|
|||||||
Check we get a single BOS token for llama chat.
|
Check we get a single BOS token for llama chat.
|
||||||
"""
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": "Hello!"},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Hello!"
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
outputs = text_llm.chat(messages)
|
outputs = text_llm.chat(messages)
|
||||||
assert len(outputs) == 1
|
assert len(outputs) == 1
|
||||||
@@ -167,14 +137,8 @@ def thinking_llm():
|
|||||||
@pytest.mark.parametrize("enable_thinking", [True, False])
|
@pytest.mark.parametrize("enable_thinking", [True, False])
|
||||||
def test_chat_extra_kwargs(thinking_llm, enable_thinking):
|
def test_chat_extra_kwargs(thinking_llm, enable_thinking):
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": "What is 1+1?"},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "What is 1+1?"
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
outputs = thinking_llm.chat(
|
outputs = thinking_llm.chat(
|
||||||
|
|||||||
@@ -23,9 +23,11 @@ def test_collective_rpc(tp_size, backend, monkeypatch):
|
|||||||
return self.rank
|
return self.rank
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
llm = LLM(
|
||||||
|
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
load_format="dummy",
|
load_format="dummy",
|
||||||
tensor_parallel_size=tp_size,
|
tensor_parallel_size=tp_size,
|
||||||
distributed_executor_backend=backend)
|
distributed_executor_backend=backend,
|
||||||
|
)
|
||||||
assert llm.collective_rpc(echo_rank) == list(range(tp_size))
|
assert llm.collective_rpc(echo_rank) == list(range(tp_size))
|
||||||
|
|||||||
@@ -29,11 +29,13 @@ TOKEN_IDS = [
|
|||||||
def llm():
|
def llm():
|
||||||
# pytest caches the fixture so we use weakref.proxy to
|
# pytest caches the fixture so we use weakref.proxy to
|
||||||
# enable garbage collection
|
# enable garbage collection
|
||||||
llm = LLM(model=MODEL_NAME,
|
llm = LLM(
|
||||||
|
model=MODEL_NAME,
|
||||||
max_num_batched_tokens=4096,
|
max_num_batched_tokens=4096,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
gpu_memory_utilization=0.10,
|
gpu_memory_utilization=0.10,
|
||||||
enforce_eager=True)
|
enforce_eager=True,
|
||||||
|
)
|
||||||
|
|
||||||
yield weakref.proxy(llm)
|
yield weakref.proxy(llm)
|
||||||
|
|
||||||
@@ -81,7 +83,8 @@ def test_max_model_len():
|
|||||||
outputs = llm.generate(PROMPTS, sampling_params)
|
outputs = llm.generate(PROMPTS, sampling_params)
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
num_total_tokens = len(output.prompt_token_ids) + len(
|
num_total_tokens = len(output.prompt_token_ids) + len(
|
||||||
output.outputs[0].token_ids)
|
output.outputs[0].token_ids
|
||||||
|
)
|
||||||
# Total tokens must not exceed max_model_len + 1 (the last token can be
|
# Total tokens must not exceed max_model_len + 1 (the last token can be
|
||||||
# generated with the context length equal to the max model length)
|
# generated with the context length equal to the max model length)
|
||||||
# It can be less if generation finishes due to other reasons (e.g., EOS)
|
# It can be less if generation finishes due to other reasons (e.g., EOS)
|
||||||
|
|||||||
@@ -16,9 +16,8 @@ def test_gpu_memory_utilization():
|
|||||||
# makes sure gpu_memory_utilization is per-instance limit,
|
# makes sure gpu_memory_utilization is per-instance limit,
|
||||||
# not a global limit
|
# not a global limit
|
||||||
llms = [
|
llms = [
|
||||||
LLM(model="facebook/opt-125m",
|
LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, enforce_eager=True)
|
||||||
gpu_memory_utilization=0.3,
|
for i in range(3)
|
||||||
enforce_eager=True) for i in range(3)
|
|
||||||
]
|
]
|
||||||
for llm in llms:
|
for llm in llms:
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ from vllm import LLM
|
|||||||
|
|
||||||
def test_empty_prompt():
|
def test_empty_prompt():
|
||||||
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
||||||
with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
|
with pytest.raises(ValueError, match="decoder prompt cannot be empty"):
|
||||||
llm.generate([""])
|
llm.generate([""])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_v1
|
@pytest.mark.skip_v1
|
||||||
def test_out_of_vocab_token():
|
def test_out_of_vocab_token():
|
||||||
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
||||||
with pytest.raises(ValueError, match='out of vocabulary'):
|
with pytest.raises(ValueError, match="out of vocabulary"):
|
||||||
llm.generate({"prompt_token_ids": [999999]})
|
llm.generate({"prompt_token_ids": [999999]})
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Tests for HF_HUB_OFFLINE mode"""
|
"""Tests for HF_HUB_OFFLINE mode"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import importlib
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
@@ -91,12 +92,11 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch):
|
|||||||
|
|
||||||
|
|
||||||
def _re_import_modules():
|
def _re_import_modules():
|
||||||
hf_hub_module_names = [
|
hf_hub_module_names = [k for k in sys.modules if k.startswith("huggingface_hub")]
|
||||||
k for k in sys.modules if k.startswith("huggingface_hub")
|
|
||||||
]
|
|
||||||
transformers_module_names = [
|
transformers_module_names = [
|
||||||
k for k in sys.modules if k.startswith("transformers")
|
k
|
||||||
and not k.startswith("transformers_modules")
|
for k in sys.modules
|
||||||
|
if k.startswith("transformers") and not k.startswith("transformers_modules")
|
||||||
]
|
]
|
||||||
|
|
||||||
reload_exception = None
|
reload_exception = None
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ from vllm.assets.audio import AudioAsset
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mary_had_lamb():
|
def mary_had_lamb():
|
||||||
path = AudioAsset('mary_had_lamb').get_local_path()
|
path = AudioAsset("mary_had_lamb").get_local_path()
|
||||||
with open(str(path), "rb") as f:
|
with open(str(path), "rb") as f:
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def winning_call():
|
def winning_call():
|
||||||
path = AudioAsset('winning_call').get_local_path()
|
path = AudioAsset("winning_call").get_local_path()
|
||||||
with open(str(path), "rb") as f:
|
with open(str(path), "rb") as f:
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
@@ -22,6 +22,6 @@ def winning_call():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def foscolo():
|
def foscolo():
|
||||||
# Test translation it->en
|
# Test translation it->en
|
||||||
path = AudioAsset('azacinto_foscolo').get_local_path()
|
path = AudioAsset("azacinto_foscolo").get_local_path()
|
||||||
with open(str(path), "rb") as f:
|
with open(str(path), "rb") as f:
|
||||||
yield f
|
yield f
|
||||||
|
|||||||
@@ -44,14 +44,15 @@ def run_test(more_args):
|
|||||||
print(f"Running with: {args}")
|
print(f"Running with: {args}")
|
||||||
|
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
MODEL_NAME, args,
|
MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS
|
||||||
max_wait_seconds=MAX_WAIT_SECONDS) as remote_server:
|
) as remote_server:
|
||||||
url = f"{remote_server.url_for('v1')}/completions"
|
url = f"{remote_server.url_for('v1')}/completions"
|
||||||
|
|
||||||
model_args = (
|
model_args = (
|
||||||
f"model={MODEL_NAME},"
|
f"model={MODEL_NAME},"
|
||||||
f"base_url={url},"
|
f"base_url={url},"
|
||||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
|
||||||
|
)
|
||||||
|
|
||||||
results = lm_eval.simple_evaluate(
|
results = lm_eval.simple_evaluate(
|
||||||
model="local-completions",
|
model="local-completions",
|
||||||
@@ -60,15 +61,18 @@ def run_test(more_args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
measured_value = results["results"][TASK][FILTER]
|
measured_value = results["results"][TASK][FILTER]
|
||||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
assert (
|
||||||
|
measured_value - RTOL < EXPECTED_VALUE
|
||||||
and measured_value + RTOL > EXPECTED_VALUE
|
and measured_value + RTOL > EXPECTED_VALUE
|
||||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
|
not current_platform.is_cuda()
|
||||||
and not current_platform.is_tpu()
|
and not current_platform.is_tpu()
|
||||||
and not current_platform.is_xpu(),
|
and not current_platform.is_xpu(),
|
||||||
reason="V1 currently only supported on CUDA, XPU and TPU")
|
reason="V1 currently only supported on CUDA, XPU and TPU",
|
||||||
|
)
|
||||||
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
|
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Run with the V1 Engine."""
|
"""Run with the V1 Engine."""
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ a baseline.
|
|||||||
This simulates real work usage of the API and makes sure that the frontend and
|
This simulates real work usage of the API and makes sure that the frontend and
|
||||||
AsyncLLMEngine are working correctly.
|
AsyncLLMEngine are working correctly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
@@ -45,7 +46,8 @@ async def transcribe_audio(client, tokenizer, y, sr):
|
|||||||
# NOTE there's no streaming in transcriptions, can't measure ttft
|
# NOTE there's no streaming in transcriptions, can't measure ttft
|
||||||
latency = end_time - start_time
|
latency = end_time - start_time
|
||||||
num_output_tokens = len(
|
num_output_tokens = len(
|
||||||
tokenizer(transcription.text, add_special_tokens=False).input_ids)
|
tokenizer(transcription.text, add_special_tokens=False).input_ids
|
||||||
|
)
|
||||||
return latency, num_output_tokens, transcription.text
|
return latency, num_output_tokens, transcription.text
|
||||||
|
|
||||||
|
|
||||||
@@ -73,8 +75,8 @@ async def process_dataset(model, client, data, concurrent_request):
|
|||||||
for sample in data:
|
for sample in data:
|
||||||
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
bound_transcribe(sem, client, tokenizer, (audio, sr),
|
bound_transcribe(sem, client, tokenizer, (audio, sr), sample["text"])
|
||||||
sample["text"]))
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
@@ -98,34 +100,35 @@ def print_performance_metrics(results, total_time):
|
|||||||
|
|
||||||
|
|
||||||
def add_duration(sample):
|
def add_duration(sample):
|
||||||
y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
|
y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||||
sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
|
sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
|
def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs):
|
||||||
## Load and filter the dataset
|
## Load and filter the dataset
|
||||||
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
|
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
|
||||||
if 'duration_ms' not in dataset[0]:
|
if "duration_ms" not in dataset[0]:
|
||||||
# compute duration to filter
|
# compute duration to filter
|
||||||
dataset = dataset.map(add_duration)
|
dataset = dataset.map(add_duration)
|
||||||
|
|
||||||
# Whisper max supported duration
|
# Whisper max supported duration
|
||||||
dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
|
dataset = dataset.filter(lambda example: example["duration_ms"] < 30000)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def run_evaluation(model: str,
|
def run_evaluation(
|
||||||
|
model: str,
|
||||||
client,
|
client,
|
||||||
dataset,
|
dataset,
|
||||||
max_concurrent_reqs: int,
|
max_concurrent_reqs: int,
|
||||||
n_examples: int = -1,
|
n_examples: int = -1,
|
||||||
print_metrics: bool = True):
|
print_metrics: bool = True,
|
||||||
|
):
|
||||||
if n_examples > 0:
|
if n_examples > 0:
|
||||||
dataset = dataset.select(range(n_examples))
|
dataset = dataset.select(range(n_examples))
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
results = asyncio.run(
|
results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs))
|
||||||
process_dataset(model, client, dataset, max_concurrent_reqs))
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
total_time = end - start
|
total_time = end - start
|
||||||
print(f"Total Test Time: {total_time:.4f} seconds")
|
print(f"Total Test Time: {total_time:.4f} seconds")
|
||||||
@@ -135,8 +138,7 @@ def run_evaluation(model: str,
|
|||||||
predictions = [res[2] for res in results]
|
predictions = [res[2] for res in results]
|
||||||
references = [res[3] for res in results]
|
references = [res[3] for res in results]
|
||||||
wer = load("wer")
|
wer = load("wer")
|
||||||
wer_score = 100 * wer.compute(references=references,
|
wer_score = 100 * wer.compute(references=references, predictions=predictions)
|
||||||
predictions=predictions)
|
|
||||||
print("WER:", wer_score)
|
print("WER:", wer_score)
|
||||||
return wer_score
|
return wer_score
|
||||||
|
|
||||||
@@ -145,26 +147,25 @@ def run_evaluation(model: str,
|
|||||||
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
|
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
|
||||||
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
|
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
|
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
|
||||||
|
)
|
||||||
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
||||||
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
||||||
@pytest.mark.parametrize("expected_wer", [12.744980])
|
@pytest.mark.parametrize("expected_wer", [12.744980])
|
||||||
def test_wer_correctness(model_name,
|
def test_wer_correctness(
|
||||||
dataset_repo,
|
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
|
||||||
expected_wer,
|
):
|
||||||
n_examples=-1,
|
|
||||||
max_concurrent_request=None):
|
|
||||||
# TODO refactor to use `ASRDataset`
|
# TODO refactor to use `ASRDataset`
|
||||||
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
|
with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
|
||||||
dataset = load_hf_dataset(dataset_repo)
|
dataset = load_hf_dataset(dataset_repo)
|
||||||
|
|
||||||
if not max_concurrent_request:
|
if not max_concurrent_request:
|
||||||
# No max concurrency
|
# No max concurrency
|
||||||
max_concurrent_request = n_examples if n_examples > 0\
|
max_concurrent_request = n_examples if n_examples > 0 else len(dataset)
|
||||||
else len(dataset)
|
|
||||||
|
|
||||||
client = remote_server.get_async_client()
|
client = remote_server.get_async_client()
|
||||||
wer = run_evaluation(model_name, client, dataset,
|
wer = run_evaluation(
|
||||||
max_concurrent_request, n_examples)
|
model_name, client, dataset, max_concurrent_request, n_examples
|
||||||
|
)
|
||||||
if expected_wer:
|
if expected_wer:
|
||||||
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
|
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
|
||||||
|
|||||||
@@ -44,15 +44,11 @@ async def client(server):
|
|||||||
ids=["completion", "chat"],
|
ids=["completion", "chat"],
|
||||||
argnames=["create_func_gen", "content_body"],
|
argnames=["create_func_gen", "content_body"],
|
||||||
argvalues=[
|
argvalues=[
|
||||||
(lambda x: x.completions.create, {
|
(lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}),
|
||||||
"prompt": " ".join(['A'] * 10_000)
|
(
|
||||||
}),
|
lambda x: x.chat.completions.create,
|
||||||
(lambda x: x.chat.completions.create, {
|
{"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]},
|
||||||
"messages": [{
|
),
|
||||||
"role": "user",
|
|
||||||
"content": " ".join(['A'] * 10_000)
|
|
||||||
}]
|
|
||||||
}),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_with_and_without_truncate(
|
async def test_with_and_without_truncate(
|
||||||
@@ -65,15 +61,15 @@ async def test_with_and_without_truncate(
|
|||||||
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
||||||
|
|
||||||
num_requests = 10
|
num_requests = 10
|
||||||
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
|
truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * (
|
||||||
(num_requests - num_requests // 2))
|
num_requests - num_requests // 2
|
||||||
|
)
|
||||||
random.shuffle(truncate_prompt_tokens)
|
random.shuffle(truncate_prompt_tokens)
|
||||||
|
|
||||||
bodies = [{
|
bodies = [
|
||||||
**body, "extra_body": {
|
{**body, "extra_body": {"truncate_prompt_tokens": t}}
|
||||||
'truncate_prompt_tokens': t
|
for t in truncate_prompt_tokens
|
||||||
}
|
]
|
||||||
} for t in truncate_prompt_tokens]
|
|
||||||
|
|
||||||
async def get_status_code(**kwargs):
|
async def get_status_code(**kwargs):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -56,24 +56,18 @@ def base64_encoded_audio() -> dict[str, str]:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
async def test_single_chat_session_audio(
|
||||||
model_name: str, audio_url: str):
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||||
messages = [{
|
):
|
||||||
"role":
|
messages = [
|
||||||
"user",
|
{
|
||||||
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||||
"type": "audio_url",
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
"audio_url": {
|
|
||||||
"url": audio_url
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -82,13 +76,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
|||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||||
|
)
|
||||||
|
|
||||||
message = choice.message
|
message = choice.message
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
@@ -110,56 +106,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI,
|
async def test_error_on_invalid_audio_url_type(
|
||||||
model_name: str,
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||||
audio_url: str):
|
):
|
||||||
messages = [{
|
messages = [
|
||||||
"role":
|
{
|
||||||
"user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{"type": "audio_url", "audio_url": audio_url},
|
||||||
"type": "audio_url",
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
"audio_url": audio_url
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# audio_url should be a dict {"url": "some url"}, not directly a string
|
# audio_url should be a dict {"url": "some url"}, not directly a string
|
||||||
with pytest.raises(openai.BadRequestError):
|
with pytest.raises(openai.BadRequestError):
|
||||||
_ = await client.chat.completions.create(model=model_name,
|
_ = await client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
temperature=0.0)
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_single_chat_session_audio_base64encoded(
|
async def test_single_chat_session_audio_base64encoded(
|
||||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI,
|
||||||
base64_encoded_audio: dict[str, str]):
|
model_name: str,
|
||||||
|
audio_url: str,
|
||||||
messages = [{
|
base64_encoded_audio: dict[str, str],
|
||||||
"role":
|
):
|
||||||
"user",
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "audio_url",
|
"type": "audio_url",
|
||||||
"audio_url": {
|
"audio_url": {
|
||||||
"url":
|
"url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
|
||||||
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
},
|
||||||
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -168,13 +160,15 @@ async def test_single_chat_session_audio_base64encoded(
|
|||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||||
|
)
|
||||||
|
|
||||||
message = choice.message
|
message = choice.message
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
@@ -198,25 +192,26 @@ async def test_single_chat_session_audio_base64encoded(
|
|||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_single_chat_session_input_audio(
|
async def test_single_chat_session_input_audio(
|
||||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI,
|
||||||
base64_encoded_audio: dict[str, str]):
|
model_name: str,
|
||||||
messages = [{
|
audio_url: str,
|
||||||
"role":
|
base64_encoded_audio: dict[str, str],
|
||||||
"user",
|
):
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "input_audio",
|
"type": "input_audio",
|
||||||
"input_audio": {
|
"input_audio": {
|
||||||
"data": base64_encoded_audio[audio_url],
|
"data": base64_encoded_audio[audio_url],
|
||||||
"format": "wav"
|
"format": "wav",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
},
|
||||||
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -224,13 +219,15 @@ async def test_single_chat_session_input_audio(
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||||
|
)
|
||||||
|
|
||||||
message = choice.message
|
message = choice.message
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
@@ -252,24 +249,18 @@ async def test_single_chat_session_input_audio(
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
async def test_chat_streaming_audio(
|
||||||
model_name: str, audio_url: str):
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||||
messages = [{
|
):
|
||||||
"role":
|
messages = [
|
||||||
"user",
|
{
|
||||||
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||||
"type": "audio_url",
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
"audio_url": {
|
|
||||||
"url": audio_url
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -309,27 +300,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
async def test_chat_streaming_input_audio(
|
||||||
model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI,
|
||||||
base64_encoded_audio: dict[str,
|
model_name: str,
|
||||||
str]):
|
audio_url: str,
|
||||||
messages = [{
|
base64_encoded_audio: dict[str, str],
|
||||||
"role":
|
):
|
||||||
"user",
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "input_audio",
|
"type": "input_audio",
|
||||||
"input_audio": {
|
"input_audio": {
|
||||||
"data": base64_encoded_audio[audio_url],
|
"data": base64_encoded_audio[audio_url],
|
||||||
"format": "wav"
|
"format": "wav",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
},
|
||||||
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -369,26 +360,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]])
|
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]
|
||||||
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
|
)
|
||||||
audio_urls: list[str]):
|
async def test_multi_audio_input(
|
||||||
|
client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str]
|
||||||
messages = [{
|
):
|
||||||
"role":
|
messages = [
|
||||||
"user",
|
|
||||||
"content": [
|
|
||||||
*({
|
|
||||||
"type": "audio_url",
|
|
||||||
"audio_url": {
|
|
||||||
"url": audio_url
|
|
||||||
}
|
|
||||||
} for audio_url in audio_urls),
|
|
||||||
{
|
{
|
||||||
"type": "text",
|
"role": "user",
|
||||||
"text": "What's happening in this audio?"
|
"content": [
|
||||||
},
|
*(
|
||||||
|
{"type": "audio_url", "audio_url": {"url": audio_url}}
|
||||||
|
for audio_url in audio_urls
|
||||||
|
),
|
||||||
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
if len(audio_urls) > MAXIMUM_AUDIOS:
|
if len(audio_urls) > MAXIMUM_AUDIOS:
|
||||||
with pytest.raises(openai.BadRequestError): # test multi-audio input
|
with pytest.raises(openai.BadRequestError): # test multi-audio input
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ from ...utils import RemoteOpenAIServer
|
|||||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope="module")
|
||||||
def server_args(request: pytest.FixtureRequest) -> list[str]:
|
def server_args(request: pytest.FixtureRequest) -> list[str]:
|
||||||
""" Provide extra arguments to the server via indirect parametrization
|
"""Provide extra arguments to the server via indirect parametrization
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
@@ -80,8 +80,10 @@ async def client(server):
|
|||||||
"server_args",
|
"server_args",
|
||||||
[
|
[
|
||||||
pytest.param([], id="default-frontend-multiprocessing"),
|
pytest.param([], id="default-frontend-multiprocessing"),
|
||||||
pytest.param(["--disable-frontend-multiprocessing"],
|
pytest.param(
|
||||||
id="disable-frontend-multiprocessing")
|
["--disable-frontend-multiprocessing"],
|
||||||
|
id="disable-frontend-multiprocessing",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer):
|
|||||||
"server_args",
|
"server_args",
|
||||||
[
|
[
|
||||||
pytest.param([], id="default-frontend-multiprocessing"),
|
pytest.param([], id="default-frontend-multiprocessing"),
|
||||||
pytest.param(["--disable-frontend-multiprocessing"],
|
pytest.param(
|
||||||
id="disable-frontend-multiprocessing")
|
["--disable-frontend-multiprocessing"],
|
||||||
|
id="disable-frontend-multiprocessing",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"server_args",
|
"server_args",
|
||||||
[
|
[
|
||||||
pytest.param(["--max-model-len", "10100"],
|
pytest.param(
|
||||||
id="default-frontend-multiprocessing"),
|
["--max-model-len", "10100"], id="default-frontend-multiprocessing"
|
||||||
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
|
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
|
||||||
id="disable-frontend-multiprocessing")
|
id="disable-frontend-multiprocessing",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
|||||||
# Request about 2 million tokens
|
# Request about 2 million tokens
|
||||||
for _ in range(200):
|
for _ in range(200):
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
client.chat.completions.create(messages=chat_input,
|
client.chat.completions.create(
|
||||||
|
messages=chat_input,
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
max_tokens=10000,
|
max_tokens=10000,
|
||||||
extra_body={"min_tokens": 10000}))
|
extra_body={"min_tokens": 10000},
|
||||||
|
)
|
||||||
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
done, pending = await asyncio.wait(tasks,
|
done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||||
return_when=asyncio.ALL_COMPLETED)
|
|
||||||
|
|
||||||
# Make sure all requests were sent to the server and timed out
|
# Make sure all requests were sent to the server and timed out
|
||||||
# (We don't want to hide other errors like 400s that would invalidate this
|
# (We don't want to hide other errors like 400s that would invalidate this
|
||||||
@@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
|||||||
# If the server had not cancelled all the other requests, then it would not
|
# If the server had not cancelled all the other requests, then it would not
|
||||||
# be able to respond to this one within the timeout
|
# be able to respond to this one within the timeout
|
||||||
client = server.get_async_client(timeout=5)
|
client = server.get_async_client(timeout=5)
|
||||||
response = await client.chat.completions.create(messages=chat_input,
|
response = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
messages=chat_input, model=MODEL_NAME, max_tokens=10
|
||||||
max_tokens=10)
|
)
|
||||||
|
|
||||||
assert len(response.choices) == 1
|
assert len(response.choices) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
||||||
|
|
||||||
chat_input = [{"role": "user", "content": "Write a long story"}]
|
chat_input = [{"role": "user", "content": "Write a long story"}]
|
||||||
client = server.get_async_client()
|
client = server.get_async_client()
|
||||||
|
|
||||||
@@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
|||||||
messages=chat_input,
|
messages=chat_input,
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
max_tokens=10000,
|
max_tokens=10000,
|
||||||
extra_headers={
|
extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
"Content-Type": "application/x-www-form-urlencoded"
|
)
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"server_args",
|
"server_args",
|
||||||
[
|
[pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")],
|
||||||
pytest.param(["--enable-server-load-tracking"],
|
|
||||||
id="enable-server-load-tracking")
|
|
||||||
],
|
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -202,7 +205,8 @@ async def test_server_load(server: RemoteOpenAIServer):
|
|||||||
|
|
||||||
# Start the completion request in a background thread.
|
# Start the completion request in a background thread.
|
||||||
completion_future = asyncio.create_task(
|
completion_future = asyncio.create_task(
|
||||||
asyncio.to_thread(make_long_completion_request))
|
asyncio.to_thread(make_long_completion_request)
|
||||||
|
)
|
||||||
|
|
||||||
# Give a short delay to ensure the request has started.
|
# Give a short delay to ensure the request has started.
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|||||||
@@ -23,14 +23,15 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
|||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def monkeypatch_module():
|
def monkeypatch_module():
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
|
||||||
mpatch = MonkeyPatch()
|
mpatch = MonkeyPatch()
|
||||||
yield mpatch
|
yield mpatch
|
||||||
mpatch.undo()
|
mpatch.undo()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server(monkeypatch_module, zephyr_lora_files): #noqa: F811
|
def server(monkeypatch_module, zephyr_lora_files): # noqa: F811
|
||||||
monkeypatch_module.setenv('VLLM_USE_V1', '1')
|
monkeypatch_module.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
# use half precision for speed and memory savings in CI environment
|
# use half precision for speed and memory savings in CI environment
|
||||||
@@ -68,20 +69,18 @@ async def client(server):
|
|||||||
[MODEL_NAME, "zephyr-lora"],
|
[MODEL_NAME, "zephyr-lora"],
|
||||||
)
|
)
|
||||||
async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{"role": "user", "content": "what is 1+1?"},
|
||||||
}, {
|
]
|
||||||
"role": "user",
|
|
||||||
"content": "what is 1+1?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=5,
|
max_completion_tokens=5,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
logprobs=False)
|
logprobs=False,
|
||||||
|
)
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.logprobs is None
|
assert choice.logprobs is None
|
||||||
@@ -94,13 +93,10 @@ async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
[MODEL_NAME, "zephyr-lora"],
|
[MODEL_NAME, "zephyr-lora"],
|
||||||
)
|
)
|
||||||
async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{"role": "user", "content": "what is 1+1?"},
|
||||||
}, {
|
]
|
||||||
"role": "user",
|
|
||||||
"content": "what is 1+1?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@@ -108,7 +104,8 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
max_completion_tokens=5,
|
max_completion_tokens=5,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=0)
|
top_logprobs=0,
|
||||||
|
)
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.logprobs is not None
|
assert choice.logprobs is not None
|
||||||
@@ -122,13 +119,10 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
[MODEL_NAME, "zephyr-lora"],
|
[MODEL_NAME, "zephyr-lora"],
|
||||||
)
|
)
|
||||||
async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{"role": "user", "content": "what is 1+1?"},
|
||||||
}, {
|
]
|
||||||
"role": "user",
|
|
||||||
"content": "what is 1+1?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@@ -136,7 +130,8 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
max_completion_tokens=5,
|
max_completion_tokens=5,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.logprobs is not None
|
assert choice.logprobs is not None
|
||||||
@@ -149,41 +144,39 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
"model_name",
|
"model_name",
|
||||||
[MODEL_NAME, "zephyr-lora"],
|
[MODEL_NAME, "zephyr-lora"],
|
||||||
)
|
)
|
||||||
async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
|
async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||||
model_name: str):
|
messages = [
|
||||||
messages = [{
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": "what is 1+1?"},
|
||||||
"content": "you are a helpful assistant"
|
]
|
||||||
}, {
|
|
||||||
"role": "user",
|
|
||||||
"content": "what is 1+1?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
# Default max_logprobs is 20, so this should raise an error
|
# Default max_logprobs is 20, so this should raise an error
|
||||||
with pytest.raises((openai.BadRequestError, openai.APIError)):
|
with pytest.raises((openai.BadRequestError, openai.APIError)):
|
||||||
stream = await client.chat.completions.create(model=model_name,
|
stream = await client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=21,
|
top_logprobs=21,
|
||||||
stream=True)
|
stream=True,
|
||||||
|
)
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
...
|
...
|
||||||
|
|
||||||
with pytest.raises(openai.BadRequestError):
|
with pytest.raises(openai.BadRequestError):
|
||||||
await client.chat.completions.create(model=model_name,
|
await client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=30,
|
top_logprobs=30,
|
||||||
stream=False)
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
# the server should still work afterwards
|
# the server should still work afterwards
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=model_name,
|
model=model_name, messages=messages, max_completion_tokens=10, stream=False
|
||||||
messages=messages,
|
)
|
||||||
max_completion_tokens=10,
|
|
||||||
stream=False)
|
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
assert message.content is not None and len(message.content) >= 0
|
assert message.content is not None and len(message.content) >= 0
|
||||||
|
|
||||||
@@ -193,27 +186,20 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
|
|||||||
"model_name, prompt_logprobs",
|
"model_name, prompt_logprobs",
|
||||||
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
|
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
|
||||||
)
|
)
|
||||||
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
async def test_prompt_logprobs_chat(
|
||||||
model_name: str,
|
client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int]
|
||||||
prompt_logprobs: Optional[int]):
|
):
|
||||||
params: dict = {
|
params: dict = {
|
||||||
"messages": [{
|
"messages": [
|
||||||
"role": "system",
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
"content": "You are a helpful assistant."
|
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||||
}, {
|
{
|
||||||
"role": "user",
|
"role": "assistant",
|
||||||
"content": "Who won the world series in 2020?"
|
"content": "The Los Angeles Dodgers won the World Series in 2020.",
|
||||||
}, {
|
},
|
||||||
"role":
|
{"role": "user", "content": "Where was it played?"},
|
||||||
"assistant",
|
],
|
||||||
"content":
|
"model": model_name,
|
||||||
"The Los Angeles Dodgers won the World Series in 2020."
|
|
||||||
}, {
|
|
||||||
"role": "user",
|
|
||||||
"content": "Where was it played?"
|
|
||||||
}],
|
|
||||||
"model":
|
|
||||||
model_name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if prompt_logprobs is not None:
|
if prompt_logprobs is not None:
|
||||||
@@ -236,29 +222,21 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
|||||||
"model_name",
|
"model_name",
|
||||||
[MODEL_NAME],
|
[MODEL_NAME],
|
||||||
)
|
)
|
||||||
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
async def test_more_than_one_prompt_logprobs_chat(
|
||||||
model_name: str):
|
client: openai.AsyncOpenAI, model_name: str
|
||||||
|
):
|
||||||
params: dict = {
|
params: dict = {
|
||||||
"messages": [{
|
"messages": [
|
||||||
"role": "system",
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
"content": "You are a helpful assistant."
|
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||||
}, {
|
{
|
||||||
"role": "user",
|
"role": "assistant",
|
||||||
"content": "Who won the world series in 2020?"
|
"content": "The Los Angeles Dodgers won the World Series in 2020.",
|
||||||
}, {
|
},
|
||||||
"role":
|
{"role": "user", "content": "Where was it played?"},
|
||||||
"assistant",
|
],
|
||||||
"content":
|
"model": model_name,
|
||||||
"The Los Angeles Dodgers won the World Series in 2020."
|
"extra_body": {"prompt_logprobs": 1},
|
||||||
}, {
|
|
||||||
"role": "user",
|
|
||||||
"content": "Where was it played?"
|
|
||||||
}],
|
|
||||||
"model":
|
|
||||||
model_name,
|
|
||||||
"extra_body": {
|
|
||||||
"prompt_logprobs": 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
completion_1 = await client.chat.completions.create(**params)
|
completion_1 = await client.chat.completions.create(**params)
|
||||||
@@ -275,15 +253,11 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
|||||||
"model_name",
|
"model_name",
|
||||||
[MODEL_NAME, "zephyr-lora"],
|
[MODEL_NAME, "zephyr-lora"],
|
||||||
)
|
)
|
||||||
async def test_single_chat_session(client: openai.AsyncOpenAI,
|
async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str):
|
||||||
model_name: str):
|
messages = [
|
||||||
messages = [{
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": "what is 1+1?"},
|
||||||
"content": "you are a helpful assistant"
|
]
|
||||||
}, {
|
|
||||||
"role": "user",
|
|
||||||
"content": "what is 1+1?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -291,14 +265,16 @@ async def test_single_chat_session(client: openai.AsyncOpenAI,
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
assert chat_completion.id is not None
|
assert chat_completion.id is not None
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||||
completion_tokens=10, prompt_tokens=37, total_tokens=47)
|
completion_tokens=10, prompt_tokens=37, total_tokens=47
|
||||||
|
)
|
||||||
|
|
||||||
message = choice.message
|
message = choice.message
|
||||||
assert message.content is not None and len(message.content) >= 10
|
assert message.content is not None and len(message.content) >= 10
|
||||||
@@ -323,13 +299,10 @@ async def test_single_chat_session(client: openai.AsyncOpenAI,
|
|||||||
[MODEL_NAME, "zephyr-lora"],
|
[MODEL_NAME, "zephyr-lora"],
|
||||||
)
|
)
|
||||||
async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str):
|
async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{"role": "user", "content": "what is 1+1?"},
|
||||||
}, {
|
]
|
||||||
"role": "user",
|
|
||||||
"content": "what is 1+1?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -371,15 +344,13 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
"model_name",
|
"model_name",
|
||||||
["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
|
["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
|
||||||
)
|
)
|
||||||
async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
async def test_chat_completion_stream_options(
|
||||||
model_name: str):
|
client: openai.AsyncOpenAI, model_name: str
|
||||||
messages = [{
|
):
|
||||||
"role": "system",
|
messages = [
|
||||||
"content": "You are a helpful assistant."
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
}, {
|
{"role": "user", "content": "What is the capital of France?"},
|
||||||
"role": "user",
|
]
|
||||||
"content": "What is the capital of France?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
# Test stream=True, stream_options={"include_usage": False}
|
# Test stream=True, stream_options={"include_usage": False}
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
@@ -388,23 +359,21 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
|||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
stream_options={"include_usage": False})
|
stream_options={"include_usage": False},
|
||||||
|
)
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
assert chunk.usage is None
|
assert chunk.usage is None
|
||||||
|
|
||||||
# Test stream=True, stream_options={"include_usage": True,
|
# Test stream=True, stream_options={"include_usage": True,
|
||||||
# "continuous_usage_stats": False}}
|
# "continuous_usage_stats": False}}
|
||||||
stream = await client.chat.completions.create(model=model_name,
|
stream = await client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
stream_options={
|
stream_options={"include_usage": True, "continuous_usage_stats": False},
|
||||||
"include_usage":
|
)
|
||||||
True,
|
|
||||||
"continuous_usage_stats":
|
|
||||||
False
|
|
||||||
})
|
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
if chunk.choices[0].finish_reason is None:
|
if chunk.choices[0].finish_reason is None:
|
||||||
@@ -416,8 +385,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
|||||||
assert final_chunk.usage.prompt_tokens > 0
|
assert final_chunk.usage.prompt_tokens > 0
|
||||||
assert final_chunk.usage.completion_tokens > 0
|
assert final_chunk.usage.completion_tokens > 0
|
||||||
assert final_chunk.usage.total_tokens == (
|
assert final_chunk.usage.total_tokens == (
|
||||||
final_chunk.usage.prompt_tokens +
|
final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens
|
||||||
final_chunk.usage.completion_tokens)
|
)
|
||||||
assert final_chunk.choices == []
|
assert final_chunk.choices == []
|
||||||
|
|
||||||
# Test stream=False, stream_options={"include_usage": None}
|
# Test stream=False, stream_options={"include_usage": None}
|
||||||
@@ -428,7 +397,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
|||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=False,
|
stream=False,
|
||||||
stream_options={"include_usage": None})
|
stream_options={"include_usage": None},
|
||||||
|
)
|
||||||
|
|
||||||
# Test stream=False, stream_options={"include_usage": True}
|
# Test stream=False, stream_options={"include_usage": True}
|
||||||
with pytest.raises(BadRequestError):
|
with pytest.raises(BadRequestError):
|
||||||
@@ -438,7 +408,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
|||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=False,
|
stream=False,
|
||||||
stream_options={"include_usage": True})
|
stream_options={"include_usage": True},
|
||||||
|
)
|
||||||
|
|
||||||
# Test stream=True, stream_options={"include_usage": True,
|
# Test stream=True, stream_options={"include_usage": True,
|
||||||
# "continuous_usage_stats": True}
|
# "continuous_usage_stats": True}
|
||||||
@@ -457,14 +428,17 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
|||||||
last_completion_tokens = 0
|
last_completion_tokens = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
assert chunk.usage.prompt_tokens >= 0
|
assert chunk.usage.prompt_tokens >= 0
|
||||||
assert last_completion_tokens == 0 or \
|
assert (
|
||||||
chunk.usage.completion_tokens > last_completion_tokens or \
|
last_completion_tokens == 0
|
||||||
(
|
or chunk.usage.completion_tokens > last_completion_tokens
|
||||||
not chunk.choices and
|
or (
|
||||||
chunk.usage.completion_tokens == last_completion_tokens
|
not chunk.choices
|
||||||
|
and chunk.usage.completion_tokens == last_completion_tokens
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert chunk.usage.total_tokens == (
|
||||||
|
chunk.usage.prompt_tokens + chunk.usage.completion_tokens
|
||||||
)
|
)
|
||||||
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
|
|
||||||
chunk.usage.completion_tokens)
|
|
||||||
last_completion_tokens = chunk.usage.completion_tokens
|
last_completion_tokens = chunk.usage.completion_tokens
|
||||||
|
|
||||||
assert last_completion_tokens == 10
|
assert last_completion_tokens == 10
|
||||||
@@ -475,37 +449,36 @@ async def test_structured_outputs_choice_chat(
|
|||||||
client: openai.AsyncOpenAI,
|
client: openai.AsyncOpenAI,
|
||||||
sample_structured_outputs_choices,
|
sample_structured_outputs_choices,
|
||||||
):
|
):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{
|
||||||
}, {
|
"role": "user",
|
||||||
"role":
|
"content": "The best language for type-safe systems programming is ",
|
||||||
"user",
|
},
|
||||||
"content":
|
]
|
||||||
"The best language for type-safe systems programming is "
|
|
||||||
}]
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
extra_body=dict(
|
extra_body=dict(
|
||||||
structured_outputs={"choice": sample_structured_outputs_choices}))
|
structured_outputs={"choice": sample_structured_outputs_choices}
|
||||||
|
),
|
||||||
|
)
|
||||||
choice1 = chat_completion.choices[0].message.content
|
choice1 = chat_completion.choices[0].message.content
|
||||||
assert choice1 in sample_structured_outputs_choices
|
assert choice1 in sample_structured_outputs_choices
|
||||||
|
|
||||||
messages.append({"role": "assistant", "content": choice1})
|
messages.append({"role": "assistant", "content": choice1})
|
||||||
messages.append({
|
messages.append({"role": "user", "content": "I disagree, pick another one"})
|
||||||
"role": "user",
|
|
||||||
"content": "I disagree, pick another one"
|
|
||||||
})
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
extra_body=dict(
|
extra_body=dict(
|
||||||
structured_outputs={"choice": sample_structured_outputs_choices}))
|
structured_outputs={"choice": sample_structured_outputs_choices}
|
||||||
|
),
|
||||||
|
)
|
||||||
choice2 = chat_completion.choices[0].message.content
|
choice2 = chat_completion.choices[0].message.content
|
||||||
assert choice2 in sample_structured_outputs_choices
|
assert choice2 in sample_structured_outputs_choices
|
||||||
assert choice1 != choice2
|
assert choice1 != choice2
|
||||||
@@ -516,38 +489,35 @@ async def test_structured_outputs_json_chat(
|
|||||||
client: openai.AsyncOpenAI,
|
client: openai.AsyncOpenAI,
|
||||||
sample_json_schema,
|
sample_json_schema,
|
||||||
):
|
):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{
|
||||||
}, {
|
"role": "user",
|
||||||
"role":
|
"content": f"Give an example JSON for an employee profile that "
|
||||||
"user",
|
f"fits this schema: {sample_json_schema}",
|
||||||
"content":
|
},
|
||||||
f"Give an example JSON for an employee profile that "
|
]
|
||||||
f"fits this schema: {sample_json_schema}"
|
|
||||||
}]
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=1000,
|
max_completion_tokens=1000,
|
||||||
extra_body=dict(structured_outputs={"json": sample_json_schema}))
|
extra_body=dict(structured_outputs={"json": sample_json_schema}),
|
||||||
|
)
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
assert message.content is not None
|
assert message.content is not None
|
||||||
json1 = json.loads(message.content)
|
json1 = json.loads(message.content)
|
||||||
jsonschema.validate(instance=json1, schema=sample_json_schema)
|
jsonschema.validate(instance=json1, schema=sample_json_schema)
|
||||||
|
|
||||||
messages.append({"role": "assistant", "content": message.content})
|
messages.append({"role": "assistant", "content": message.content})
|
||||||
messages.append({
|
messages.append(
|
||||||
"role":
|
{"role": "user", "content": "Give me another one with a different name and age"}
|
||||||
"user",
|
)
|
||||||
"content":
|
|
||||||
"Give me another one with a different name and age"
|
|
||||||
})
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=1000,
|
max_completion_tokens=1000,
|
||||||
extra_body=dict(structured_outputs={"json": sample_json_schema}))
|
extra_body=dict(structured_outputs={"json": sample_json_schema}),
|
||||||
|
)
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
assert message.content is not None
|
assert message.content is not None
|
||||||
json2 = json.loads(message.content)
|
json2 = json.loads(message.content)
|
||||||
@@ -561,21 +531,19 @@ async def test_structured_outputs_regex_chat(
|
|||||||
client: openai.AsyncOpenAI,
|
client: openai.AsyncOpenAI,
|
||||||
sample_regex,
|
sample_regex,
|
||||||
):
|
):
|
||||||
|
messages = [
|
||||||
messages = [{
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"role": "system",
|
{
|
||||||
"content": "you are a helpful assistant"
|
"role": "user",
|
||||||
}, {
|
"content": f"Give an example IP address with this regex: {sample_regex}",
|
||||||
"role":
|
},
|
||||||
"user",
|
]
|
||||||
"content":
|
|
||||||
f"Give an example IP address with this regex: {sample_regex}"
|
|
||||||
}]
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=20,
|
max_completion_tokens=20,
|
||||||
extra_body=dict(structured_outputs={"regex": sample_regex}))
|
extra_body=dict(structured_outputs={"regex": sample_regex}),
|
||||||
|
)
|
||||||
ip1 = chat_completion.choices[0].message.content
|
ip1 = chat_completion.choices[0].message.content
|
||||||
assert ip1 is not None
|
assert ip1 is not None
|
||||||
assert re.fullmatch(sample_regex, ip1) is not None
|
assert re.fullmatch(sample_regex, ip1) is not None
|
||||||
@@ -586,7 +554,8 @@ async def test_structured_outputs_regex_chat(
|
|||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=20,
|
max_completion_tokens=20,
|
||||||
extra_body=dict(structured_outputs={"regex": sample_regex}))
|
extra_body=dict(structured_outputs={"regex": sample_regex}),
|
||||||
|
)
|
||||||
ip2 = chat_completion.choices[0].message.content
|
ip2 = chat_completion.choices[0].message.content
|
||||||
assert ip2 is not None
|
assert ip2 is not None
|
||||||
assert re.fullmatch(sample_regex, ip2) is not None
|
assert re.fullmatch(sample_regex, ip2) is not None
|
||||||
@@ -595,40 +564,33 @@ async def test_structured_outputs_regex_chat(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_structured_outputs_type_error(client: openai.AsyncOpenAI):
|
async def test_structured_outputs_type_error(client: openai.AsyncOpenAI):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{
|
||||||
}, {
|
"role": "user",
|
||||||
"role":
|
"content": "The best language for type-safe systems programming is ",
|
||||||
"user",
|
},
|
||||||
"content":
|
]
|
||||||
"The best language for type-safe systems programming is "
|
|
||||||
}]
|
|
||||||
|
|
||||||
with pytest.raises(openai.BadRequestError):
|
with pytest.raises(openai.BadRequestError):
|
||||||
_ = await client.chat.completions.create(
|
_ = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
extra_body=dict(
|
extra_body=dict(structured_outputs={"regex": {1: "Python", 2: "C++"}}),
|
||||||
structured_outputs={"regex": {
|
)
|
||||||
1: "Python",
|
|
||||||
2: "C++"
|
|
||||||
}}))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_structured_outputs_choice_chat_logprobs(
|
async def test_structured_outputs_choice_chat_logprobs(
|
||||||
client: openai.AsyncOpenAI, sample_structured_outputs_choices):
|
client: openai.AsyncOpenAI, sample_structured_outputs_choices
|
||||||
|
):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{
|
||||||
}, {
|
"role": "user",
|
||||||
"role":
|
"content": "The best language for type-safe systems programming is ",
|
||||||
"user",
|
},
|
||||||
"content":
|
]
|
||||||
"The best language for type-safe systems programming is "
|
|
||||||
}]
|
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -636,7 +598,9 @@ async def test_structured_outputs_choice_chat_logprobs(
|
|||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=5,
|
top_logprobs=5,
|
||||||
extra_body=dict(
|
extra_body=dict(
|
||||||
structured_outputs={"choice": sample_structured_outputs_choices}))
|
structured_outputs={"choice": sample_structured_outputs_choices}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
assert chat_completion.choices[0].logprobs is not None
|
assert chat_completion.choices[0].logprobs is not None
|
||||||
assert chat_completion.choices[0].logprobs.content is not None
|
assert chat_completion.choices[0].logprobs.content is not None
|
||||||
@@ -652,29 +616,26 @@ async def test_named_tool_use(
|
|||||||
client: openai.AsyncOpenAI,
|
client: openai.AsyncOpenAI,
|
||||||
sample_json_schema,
|
sample_json_schema,
|
||||||
):
|
):
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "system",
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"content": "you are a helpful assistant"
|
{
|
||||||
}, {
|
"role": "user",
|
||||||
"role":
|
"content": (
|
||||||
"user",
|
"Give an example JSON for an employee profile using the specified tool."
|
||||||
"content": ("Give an example JSON for an employee "
|
),
|
||||||
"profile using the specified tool.")
|
},
|
||||||
}]
|
]
|
||||||
tools = [{
|
tools = [
|
||||||
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "dummy_function_name",
|
"name": "dummy_function_name",
|
||||||
"description": "This is a dummy function",
|
"description": "This is a dummy function",
|
||||||
"parameters": sample_json_schema
|
"parameters": sample_json_schema,
|
||||||
}
|
},
|
||||||
}]
|
|
||||||
tool_choice = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "dummy_function_name"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
]
|
||||||
|
tool_choice = {"type": "function", "function": {"name": "dummy_function_name"}}
|
||||||
|
|
||||||
# non-streaming
|
# non-streaming
|
||||||
|
|
||||||
@@ -692,21 +653,20 @@ async def test_named_tool_use(
|
|||||||
jsonschema.validate(instance=json1, schema=sample_json_schema)
|
jsonschema.validate(instance=json1, schema=sample_json_schema)
|
||||||
|
|
||||||
messages.append({"role": "assistant", "content": json_string})
|
messages.append({"role": "assistant", "content": json_string})
|
||||||
messages.append({
|
messages.append(
|
||||||
"role":
|
{"role": "user", "content": "Give me another one with a different name and age"}
|
||||||
"user",
|
)
|
||||||
"content":
|
|
||||||
"Give me another one with a different name and age"
|
|
||||||
})
|
|
||||||
|
|
||||||
# streaming
|
# streaming
|
||||||
|
|
||||||
stream = await client.chat.completions.create(model=MODEL_NAME,
|
stream = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=1000,
|
max_completion_tokens=1000,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
stream=True)
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
@@ -728,64 +688,66 @@ async def test_named_tool_use(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
|
async def test_inconsistent_tool_choice_and_tools(
|
||||||
sample_json_schema):
|
client: openai.AsyncOpenAI, sample_json_schema
|
||||||
messages = [{
|
):
|
||||||
"role": "system",
|
messages = [
|
||||||
"content": "you are a helpful assistant"
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
}, {
|
{
|
||||||
"role":
|
"role": "user",
|
||||||
"user",
|
"content": f"Give an example JSON for an employee profile that "
|
||||||
"content":
|
f"fits this schema: {sample_json_schema}",
|
||||||
f"Give an example JSON for an employee profile that "
|
},
|
||||||
f"fits this schema: {sample_json_schema}"
|
]
|
||||||
}]
|
|
||||||
|
|
||||||
with pytest.raises(openai.BadRequestError):
|
|
||||||
await client.chat.completions.create(model=MODEL_NAME,
|
|
||||||
messages=messages,
|
|
||||||
max_completion_tokens=1000,
|
|
||||||
tool_choice={
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name":
|
|
||||||
"dummy_function_name"
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
with pytest.raises(openai.BadRequestError):
|
with pytest.raises(openai.BadRequestError):
|
||||||
await client.chat.completions.create(
|
await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=1000,
|
max_completion_tokens=1000,
|
||||||
tools=[{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "dummy_function_name",
|
|
||||||
"description": "This is a dummy function",
|
|
||||||
"parameters": sample_json_schema
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
tool_choice={
|
tool_choice={
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {"name": "dummy_function_name"},
|
||||||
"name": "nondefined_function_name"
|
},
|
||||||
}
|
)
|
||||||
})
|
|
||||||
with pytest.raises(openai.BadRequestError):
|
with pytest.raises(openai.BadRequestError):
|
||||||
await client.chat.completions.create(
|
await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=1000,
|
max_completion_tokens=1000,
|
||||||
tools=[{
|
tools=[
|
||||||
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "dummy_function_name",
|
"name": "dummy_function_name",
|
||||||
"description": "This is a dummy function",
|
"description": "This is a dummy function",
|
||||||
"parameters": sample_json_schema
|
"parameters": sample_json_schema,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}],
|
],
|
||||||
tool_choice={})
|
tool_choice={
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "nondefined_function_name"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with pytest.raises(openai.BadRequestError):
|
||||||
|
await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=messages,
|
||||||
|
max_completion_tokens=1000,
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "dummy_function_name",
|
||||||
|
"description": "This is a dummy function",
|
||||||
|
"parameters": sample_json_schema,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_choice={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -793,13 +755,17 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
|
|||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
resp = await client.chat.completions.create(
|
resp = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[
|
||||||
"role":
|
{
|
||||||
"user",
|
"role": "user",
|
||||||
"content": ('what is 1+1? please respond with a JSON object, '
|
"content": (
|
||||||
'the format is {"result": 2}')
|
"what is 1+1? please respond with a JSON object, "
|
||||||
}],
|
'the format is {"result": 2}'
|
||||||
response_format={"type": "json_object"})
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
)
|
||||||
|
|
||||||
content = resp.choices[0].message.content
|
content = resp.choices[0].message.content
|
||||||
assert content is not None
|
assert content is not None
|
||||||
@@ -815,10 +781,7 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
|||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
resp = await client.chat.completions.create(
|
resp = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[{"role": "user", "content": prompt}],
|
||||||
"role": "user",
|
|
||||||
"content": prompt
|
|
||||||
}],
|
|
||||||
)
|
)
|
||||||
content = resp.choices[0].message.content
|
content = resp.choices[0].message.content
|
||||||
assert content is not None
|
assert content is not None
|
||||||
@@ -829,10 +792,7 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
|||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
resp = await client.chat.completions.create(
|
resp = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[{"role": "user", "content": prompt}],
|
||||||
"role": "user",
|
|
||||||
"content": prompt
|
|
||||||
}],
|
|
||||||
response_format={
|
response_format={
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
"json_schema": {
|
"json_schema": {
|
||||||
@@ -840,13 +800,12 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
|||||||
"schema": {
|
"schema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"result": {
|
"result": {"type": "integer"},
|
||||||
"type": "integer"
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
})
|
)
|
||||||
|
|
||||||
content = resp.choices[0].message.content
|
content = resp.choices[0].message.content
|
||||||
assert content is not None
|
assert content is not None
|
||||||
@@ -859,13 +818,16 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
|||||||
async def test_extra_fields_allowed(client: openai.AsyncOpenAI):
|
async def test_extra_fields_allowed(client: openai.AsyncOpenAI):
|
||||||
resp = await client.chat.completions.create(
|
resp = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[
|
||||||
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "what is 1+1?",
|
"content": "what is 1+1?",
|
||||||
"extra_field": "0",
|
"extra_field": "0",
|
||||||
}], # type: ignore
|
}
|
||||||
|
], # type: ignore
|
||||||
temperature=0,
|
temperature=0,
|
||||||
seed=0)
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
content = resp.choices[0].message.content
|
content = resp.choices[0].message.content
|
||||||
assert content is not None
|
assert content is not None
|
||||||
@@ -875,18 +837,20 @@ async def test_extra_fields_allowed(client: openai.AsyncOpenAI):
|
|||||||
async def test_complex_message_content(client: openai.AsyncOpenAI):
|
async def test_complex_message_content(client: openai.AsyncOpenAI):
|
||||||
resp = await client.chat.completions.create(
|
resp = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[
|
||||||
"role":
|
{
|
||||||
"user",
|
"role": "user",
|
||||||
"content": [{
|
"content": [
|
||||||
"type":
|
{
|
||||||
"text",
|
"type": "text",
|
||||||
"text":
|
"text": "what is 1+1? please provide the result without any other text.",
|
||||||
"what is 1+1? please provide the result without any other text."
|
}
|
||||||
}]
|
],
|
||||||
}],
|
}
|
||||||
|
],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
seed=0)
|
seed=0,
|
||||||
|
)
|
||||||
content = resp.choices[0].message.content
|
content = resp.choices[0].message.content
|
||||||
assert content == "2"
|
assert content == "2"
|
||||||
|
|
||||||
@@ -898,24 +862,27 @@ async def test_custom_role(client: openai.AsyncOpenAI):
|
|||||||
|
|
||||||
resp1 = await client.chat.completions.create(
|
resp1 = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[
|
||||||
|
{
|
||||||
"role": "my-custom-role",
|
"role": "my-custom-role",
|
||||||
"content": "what is 1+1?",
|
"content": "what is 1+1?",
|
||||||
}], # type: ignore
|
}
|
||||||
|
], # type: ignore
|
||||||
temperature=0,
|
temperature=0,
|
||||||
seed=0)
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
resp2 = await client.chat.completions.create(
|
resp2 = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[
|
||||||
|
{
|
||||||
"role": "my-custom-role",
|
"role": "my-custom-role",
|
||||||
"content": [{
|
"content": [{"type": "text", "text": "what is 1+1?"}],
|
||||||
"type": "text",
|
}
|
||||||
"text": "what is 1+1?"
|
], # type: ignore
|
||||||
}]
|
|
||||||
}], # type: ignore
|
|
||||||
temperature=0,
|
temperature=0,
|
||||||
seed=0)
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
content1 = resp1.choices[0].message.content
|
content1 = resp1.choices[0].message.content
|
||||||
content2 = resp2.choices[0].message.content
|
content2 = resp2.choices[0].message.content
|
||||||
@@ -924,34 +891,32 @@ async def test_custom_role(client: openai.AsyncOpenAI):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_long_seed(client: openai.AsyncOpenAI):
|
async def test_long_seed(client: openai.AsyncOpenAI):
|
||||||
for seed in [
|
for seed in [torch.iinfo(torch.long).min - 1, torch.iinfo(torch.long).max + 1]:
|
||||||
torch.iinfo(torch.long).min - 1,
|
|
||||||
torch.iinfo(torch.long).max + 1
|
|
||||||
]:
|
|
||||||
with pytest.raises(BadRequestError) as exc_info:
|
with pytest.raises(BadRequestError) as exc_info:
|
||||||
await client.chat.completions.create(
|
await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[
|
||||||
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": "You are a helpful assistant.",
|
"content": "You are a helpful assistant.",
|
||||||
}],
|
}
|
||||||
|
],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
seed=seed)
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
assert ("greater_than_equal" in exc_info.value.message
|
assert (
|
||||||
or "less_than_equal" in exc_info.value.message)
|
"greater_than_equal" in exc_info.value.message
|
||||||
|
or "less_than_equal" in exc_info.value.message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invocations(server: RemoteOpenAIServer,
|
async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
|
||||||
client: openai.AsyncOpenAI):
|
messages = [
|
||||||
messages = [{
|
{"role": "system", "content": "you are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": "what is 1+1?"},
|
||||||
"content": "you are a helpful assistant"
|
]
|
||||||
}, {
|
|
||||||
"role": "user",
|
|
||||||
"content": "what is 1+1?"
|
|
||||||
}]
|
|
||||||
|
|
||||||
request_args = {
|
request_args = {
|
||||||
"model": MODEL_NAME,
|
"model": MODEL_NAME,
|
||||||
@@ -963,8 +928,9 @@ async def test_invocations(server: RemoteOpenAIServer,
|
|||||||
|
|
||||||
chat_completion = await client.chat.completions.create(**request_args)
|
chat_completion = await client.chat.completions.create(**request_args)
|
||||||
|
|
||||||
invocation_response = requests.post(server.url_for("invocations"),
|
invocation_response = requests.post(
|
||||||
json=request_args)
|
server.url_for("invocations"), json=request_args
|
||||||
|
)
|
||||||
invocation_response.raise_for_status()
|
invocation_response.raise_for_status()
|
||||||
|
|
||||||
chat_output = chat_completion.model_dump()
|
chat_output = chat_completion.model_dump()
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user