Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,46 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
@@ -6,28 +6,16 @@ default_stages:
|
|||||||
- manual # Run in CI
|
- manual # Run in CI
|
||||||
exclude: 'vllm/third_party/.*'
|
exclude: 'vllm/third_party/.*'
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/google/yapf
|
|
||||||
rev: v0.43.0
|
|
||||||
hooks:
|
|
||||||
- id: yapf
|
|
||||||
args: [--in-place, --verbose]
|
|
||||||
# Keep the same list from yapfignore here to avoid yapf failing without any inputs
|
|
||||||
exclude: '(.buildkite|benchmarks|build|examples)/.*'
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.11.7
|
rev: v0.11.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--output-format, github, --fix]
|
args: [--output-format, github, --fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
files: ^(.buildkite|benchmarks|examples)/.*
|
|
||||||
- repo: https://github.com/crate-ci/typos
|
- repo: https://github.com/crate-ci/typos
|
||||||
rev: v1.35.5
|
rev: v1.35.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: typos
|
- id: typos
|
||||||
- repo: https://github.com/PyCQA/isort
|
|
||||||
rev: 6.0.1
|
|
||||||
hooks:
|
|
||||||
- id: isort
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v20.1.3
|
rev: v20.1.3
|
||||||
hooks:
|
hooks:
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
from benchmark_utils import TimeCollector
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
from benchmark_utils import TimeCollector
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import time
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from benchmark_utils import TimeCollector
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
from benchmark_utils import TimeCollector
|
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
DeviceConfig,
|
DeviceConfig,
|
||||||
|
|||||||
@@ -37,14 +37,13 @@ from typing import Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm.asyncio import tqdm
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
|
|
||||||
from backend_request_func import (
|
from backend_request_func import (
|
||||||
ASYNC_REQUEST_FUNCS,
|
ASYNC_REQUEST_FUNCS,
|
||||||
RequestFuncInput,
|
RequestFuncInput,
|
||||||
RequestFuncOutput,
|
RequestFuncOutput,
|
||||||
)
|
)
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
|
||||||
known-first-party = ["vllm"]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
@@ -16,7 +16,7 @@ import shutil
|
|||||||
|
|
||||||
from torch.utils.hipify.hipify_python import hipify
|
from torch.utils.hipify.hipify_python import hipify
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
# Project directory where all the source + include files live.
|
# Project directory where all the source + include files live.
|
||||||
@@ -34,15 +34,14 @@ if __name__ == '__main__':
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Source files to convert.
|
# Source files to convert.
|
||||||
parser.add_argument("sources",
|
parser.add_argument(
|
||||||
help="Source files to hipify.",
|
"sources", help="Source files to hipify.", nargs="*", default=[]
|
||||||
nargs="*",
|
)
|
||||||
default=[])
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Limit include scope to project_dir only
|
# Limit include scope to project_dir only
|
||||||
includes = [os.path.join(args.project_dir, '*')]
|
includes = [os.path.join(args.project_dir, "*")]
|
||||||
|
|
||||||
# Get absolute path for all source files.
|
# Get absolute path for all source files.
|
||||||
extra_files = [os.path.abspath(s) for s in args.sources]
|
extra_files = [os.path.abspath(s) for s in args.sources]
|
||||||
@@ -51,25 +50,31 @@ if __name__ == '__main__':
|
|||||||
# The directory might already exist to hold object files so we ignore that.
|
# The directory might already exist to hold object files so we ignore that.
|
||||||
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
||||||
|
|
||||||
hipify_result = hipify(project_directory=args.project_dir,
|
hipify_result = hipify(
|
||||||
output_directory=args.output_dir,
|
project_directory=args.project_dir,
|
||||||
header_include_dirs=[],
|
output_directory=args.output_dir,
|
||||||
includes=includes,
|
header_include_dirs=[],
|
||||||
extra_files=extra_files,
|
includes=includes,
|
||||||
show_detailed=True,
|
extra_files=extra_files,
|
||||||
is_pytorch_extension=True,
|
show_detailed=True,
|
||||||
hipify_extra_files_only=True)
|
is_pytorch_extension=True,
|
||||||
|
hipify_extra_files_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
hipified_sources = []
|
hipified_sources = []
|
||||||
for source in args.sources:
|
for source in args.sources:
|
||||||
s_abs = os.path.abspath(source)
|
s_abs = os.path.abspath(source)
|
||||||
hipified_s_abs = (hipify_result[s_abs].hipified_path if
|
hipified_s_abs = (
|
||||||
(s_abs in hipify_result
|
hipify_result[s_abs].hipified_path
|
||||||
and hipify_result[s_abs].hipified_path is not None)
|
if (
|
||||||
else s_abs)
|
s_abs in hipify_result
|
||||||
|
and hipify_result[s_abs].hipified_path is not None
|
||||||
|
)
|
||||||
|
else s_abs
|
||||||
|
)
|
||||||
hipified_sources.append(hipified_s_abs)
|
hipified_sources.append(hipified_s_abs)
|
||||||
|
|
||||||
assert (len(hipified_sources) == len(args.sources))
|
assert len(hipified_sources) == len(args.sources)
|
||||||
|
|
||||||
# Print hipified source files.
|
# Print hipified source files.
|
||||||
print("\n".join(hipified_sources))
|
print("\n".join(hipified_sources))
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: "u4b8",
|
VLLMDataType.u4b8: "u4b8",
|
||||||
VLLMDataType.u8b128: "u8b128",
|
VLLMDataType.u8b128: "u8b128",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
@@ -35,7 +35,7 @@ VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
||||||
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
|
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
||||||
@@ -43,7 +43,7 @@ VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
|||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: 4,
|
VLLMDataType.u4b8: 4,
|
||||||
VLLMDataType.u8b128: 8,
|
VLLMDataType.u8b128: 8,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
@@ -67,15 +67,13 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
DataType.f32: "at::ScalarType::Float",
|
DataType.f32: "at::ScalarType::Float",
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMKernelScheduleTag: dict[Union[
|
VLLMKernelScheduleTag: dict[
|
||||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
Union[MixedInputKernelScheduleType, KernelScheduleType], str
|
||||||
**KernelScheduleTag, # type: ignore
|
] = {
|
||||||
**{
|
**KernelScheduleTag, # type: ignore
|
||||||
MixedInputKernelScheduleType.TmaWarpSpecialized:
|
**{
|
||||||
"cutlass::gemm::KernelTmaWarpSpecialized",
|
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
|
||||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
|
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
|
},
|
||||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -17,25 +17,30 @@ FILE_HEAD = """
|
|||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
TEMPLATE = ("template __global__ void Marlin<"
|
TEMPLATE = (
|
||||||
"{{scalar_t}}, "
|
"template __global__ void Marlin<"
|
||||||
"{{w_type_id}}, "
|
"{{scalar_t}}, "
|
||||||
"{{s_type_id}}, "
|
"{{w_type_id}}, "
|
||||||
"{{threads}}, "
|
"{{s_type_id}}, "
|
||||||
"{{thread_m_blocks}}, "
|
"{{threads}}, "
|
||||||
"{{thread_n_blocks}}, "
|
"{{thread_m_blocks}}, "
|
||||||
"{{thread_k_blocks}}, "
|
"{{thread_n_blocks}}, "
|
||||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
"{{thread_k_blocks}}, "
|
||||||
"{{stages}}, "
|
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||||
"{{group_blocks}}, "
|
"{{stages}}, "
|
||||||
"{{'true' if is_zp_float else 'false'}}>"
|
"{{group_blocks}}, "
|
||||||
"( MARLIN_KERNEL_PARAMS );")
|
"{{'true' if is_zp_float else 'false'}}>"
|
||||||
|
"( MARLIN_KERNEL_PARAMS );"
|
||||||
|
)
|
||||||
|
|
||||||
# int8 with zero point case (vllm::kU8) is also supported,
|
# int8 with zero point case (vllm::kU8) is also supported,
|
||||||
# we don't add it to reduce wheel size.
|
# we don't add it to reduce wheel size.
|
||||||
SCALAR_TYPES = [
|
SCALAR_TYPES = [
|
||||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
"vllm::kU4",
|
||||||
"vllm::kFE2M1f"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
|
"vllm::kFE4M3fn",
|
||||||
|
"vllm::kFE2M1f",
|
||||||
]
|
]
|
||||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||||
|
|
||||||
@@ -58,11 +63,12 @@ def generate_new_kernels():
|
|||||||
all_template_str_list = []
|
all_template_str_list = []
|
||||||
|
|
||||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||||
|
):
|
||||||
# act order case only support gptq-int4 and gptq-int8
|
# act order case only support gptq-int4 and gptq-int8
|
||||||
if group_blocks == 0 and scalar_type not in [
|
if group_blocks == 0 and scalar_type not in [
|
||||||
"vllm::kU4B8", "vllm::kU8B128"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
]:
|
]:
|
||||||
continue
|
continue
|
||||||
if thread_configs[2] == 256:
|
if thread_configs[2] == 256:
|
||||||
|
|||||||
@@ -17,28 +17,32 @@ FILE_HEAD = """
|
|||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
TEMPLATE = ("template __global__ void Marlin<"
|
TEMPLATE = (
|
||||||
"{{scalar_t}}, "
|
"template __global__ void Marlin<"
|
||||||
"{{w_type_id}}, "
|
"{{scalar_t}}, "
|
||||||
"{{s_type_id}}, "
|
"{{w_type_id}}, "
|
||||||
"{{threads}}, "
|
"{{s_type_id}}, "
|
||||||
"{{thread_m_blocks}}, "
|
"{{threads}}, "
|
||||||
"{{thread_n_blocks}}, "
|
"{{thread_m_blocks}}, "
|
||||||
"{{thread_k_blocks}}, "
|
"{{thread_n_blocks}}, "
|
||||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
"{{thread_k_blocks}}, "
|
||||||
"{{stages}}, "
|
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||||
"{{group_blocks}}, "
|
"{{stages}}, "
|
||||||
"{{'true' if is_zp_float else 'false'}}>"
|
"{{group_blocks}}, "
|
||||||
"( MARLIN_KERNEL_PARAMS );")
|
"{{'true' if is_zp_float else 'false'}}>"
|
||||||
|
"( MARLIN_KERNEL_PARAMS );"
|
||||||
|
)
|
||||||
|
|
||||||
# int8 with zero point case (vllm::kU8) is also supported,
|
# int8 with zero point case (vllm::kU8) is also supported,
|
||||||
# we don't add it to reduce wheel size.
|
# we don't add it to reduce wheel size.
|
||||||
SCALAR_TYPES = [
|
SCALAR_TYPES = [
|
||||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
"vllm::kU4",
|
||||||
"vllm::kFE2M1f"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
|
"vllm::kFE4M3fn",
|
||||||
|
"vllm::kFE2M1f",
|
||||||
]
|
]
|
||||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
|
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
|
||||||
(128, 64, 128)]
|
|
||||||
|
|
||||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||||
# group_blocks:
|
# group_blocks:
|
||||||
@@ -59,11 +63,12 @@ def generate_new_kernels():
|
|||||||
all_template_str_list = []
|
all_template_str_list = []
|
||||||
|
|
||||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||||
|
):
|
||||||
# act order case only support gptq-int4 and gptq-int8
|
# act order case only support gptq-int4 and gptq-int8
|
||||||
if group_blocks == 0 and scalar_type not in [
|
if group_blocks == 0 and scalar_type not in [
|
||||||
"vllm::kU4B8", "vllm::kU8B128"
|
"vllm::kU4B8",
|
||||||
|
"vllm::kU8B128",
|
||||||
]:
|
]:
|
||||||
continue
|
continue
|
||||||
if thread_configs[2] == 256:
|
if thread_configs[2] == 256:
|
||||||
@@ -93,8 +98,7 @@ def generate_new_kernels():
|
|||||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||||
|
|
||||||
is_zp_float_list = [False]
|
is_zp_float_list = [False]
|
||||||
if dtype == "fp16" and scalar_type == "vllm::kU4" and \
|
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
|
||||||
group_blocks == 4:
|
|
||||||
# HQQ (is_zp_float = true) only supports
|
# HQQ (is_zp_float = true) only supports
|
||||||
# 4bit quantization and fp16
|
# 4bit quantization and fp16
|
||||||
is_zp_float_list.append(True)
|
is_zp_float_list.append(True)
|
||||||
|
|||||||
@@ -12,18 +12,24 @@ from functools import reduce
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
|
from vllm_cutlass_library_extension import (
|
||||||
EpilogueScheduleType,
|
DataType,
|
||||||
MixedInputKernelScheduleType,
|
EpilogueScheduleTag,
|
||||||
TileSchedulerTag,
|
EpilogueScheduleType,
|
||||||
TileSchedulerType, VLLMDataType,
|
MixedInputKernelScheduleType,
|
||||||
VLLMDataTypeNames,
|
TileSchedulerTag,
|
||||||
VLLMDataTypeSize, VLLMDataTypeTag,
|
TileSchedulerType,
|
||||||
VLLMDataTypeTorchDataTypeTag,
|
VLLMDataType,
|
||||||
VLLMDataTypeVLLMScalarTypeTag,
|
VLLMDataTypeNames,
|
||||||
VLLMKernelScheduleTag)
|
VLLMDataTypeSize,
|
||||||
|
VLLMDataTypeTag,
|
||||||
|
VLLMDataTypeTorchDataTypeTag,
|
||||||
|
VLLMDataTypeVLLMScalarTypeTag,
|
||||||
|
VLLMKernelScheduleTag,
|
||||||
|
)
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|
||||||
@@ -286,18 +292,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
|||||||
tile_shape = (
|
tile_shape = (
|
||||||
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
||||||
)
|
)
|
||||||
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
|
cluster_shape = (
|
||||||
f"x{schedule_config.cluster_shape_mnk[1]}" +
|
f"{schedule_config.cluster_shape_mnk[0]}"
|
||||||
f"x{schedule_config.cluster_shape_mnk[2]}")
|
+ f"x{schedule_config.cluster_shape_mnk[1]}"
|
||||||
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
|
+ f"x{schedule_config.cluster_shape_mnk[2]}"
|
||||||
.split("::")[-1]
|
)
|
||||||
epilogue_schedule = EpilogueScheduleTag[
|
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split(
|
||||||
schedule_config.epilogue_schedule].split("::")[-1]
|
"::"
|
||||||
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
|
)[-1]
|
||||||
.split("::")[-1]
|
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split(
|
||||||
|
"::"
|
||||||
|
)[-1]
|
||||||
|
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
|
||||||
|
|
||||||
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
|
return (
|
||||||
f"_{epilogue_schedule}_{tile_scheduler}")
|
f"{tile_shape}_{cluster_shape}_{kernel_schedule}"
|
||||||
|
+ f"_{epilogue_schedule}_{tile_scheduler}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# mostly unique shorter sch_sig
|
# mostly unique shorter sch_sig
|
||||||
@@ -316,18 +327,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
|||||||
|
|
||||||
# unique type_name
|
# unique type_name
|
||||||
def generate_type_signature(kernel_types: TypeConfig):
|
def generate_type_signature(kernel_types: TypeConfig):
|
||||||
return str("".join([
|
return str(
|
||||||
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
"".join(
|
||||||
for field in fields(TypeConfig)
|
[
|
||||||
]))
|
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||||
|
for field in fields(TypeConfig)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_type_option_name(kernel_types: TypeConfig):
|
def generate_type_option_name(kernel_types: TypeConfig):
|
||||||
return ", ".join([
|
return ", ".join(
|
||||||
f"{field.name.replace('b_', 'with_')+'_type'}=" +
|
[
|
||||||
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
f"{field.name.replace('b_', 'with_') + '_type'}="
|
||||||
for field in fields(TypeConfig)
|
+ VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||||
])
|
for field in fields(TypeConfig)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_power_of_two(n):
|
def is_power_of_two(n):
|
||||||
@@ -335,7 +352,6 @@ def is_power_of_two(n):
|
|||||||
|
|
||||||
|
|
||||||
def to_cute_constant(value: list[int]):
|
def to_cute_constant(value: list[int]):
|
||||||
|
|
||||||
def _to_cute_constant(value: int):
|
def _to_cute_constant(value: int):
|
||||||
if is_power_of_two(value):
|
if is_power_of_two(value):
|
||||||
return f"_{value}"
|
return f"_{value}"
|
||||||
@@ -350,11 +366,11 @@ def to_cute_constant(value: list[int]):
|
|||||||
|
|
||||||
def unique_schedules(impl_configs: list[ImplConfig]):
|
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||||
# Use dict over set for deterministic ordering
|
# Use dict over set for deterministic ordering
|
||||||
return list({
|
return list(
|
||||||
sch: None
|
{
|
||||||
for impl_config in impl_configs
|
sch: None for impl_config in impl_configs for sch in impl_config.schedules
|
||||||
for sch in impl_config.schedules
|
}.keys()
|
||||||
}.keys())
|
)
|
||||||
|
|
||||||
|
|
||||||
def unsigned_type_with_bitwidth(num_bits):
|
def unsigned_type_with_bitwidth(num_bits):
|
||||||
@@ -380,7 +396,7 @@ template_globals = {
|
|||||||
"gen_type_sig": generate_type_signature,
|
"gen_type_sig": generate_type_signature,
|
||||||
"unique_schedules": unique_schedules,
|
"unique_schedules": unique_schedules,
|
||||||
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
|
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
|
||||||
"gen_type_option_name": generate_type_option_name
|
"gen_type_option_name": generate_type_option_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -398,23 +414,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
|||||||
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||||
sources = []
|
sources = []
|
||||||
|
|
||||||
sources.append((
|
sources.append(
|
||||||
"machete_mm_dispatch",
|
(
|
||||||
mm_dispatch_template.render(impl_configs=impl_configs),
|
"machete_mm_dispatch",
|
||||||
))
|
mm_dispatch_template.render(impl_configs=impl_configs),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
prepack_types = []
|
prepack_types = []
|
||||||
for impl_config in impl_configs:
|
for impl_config in impl_configs:
|
||||||
convert_type = impl_config.types.a \
|
convert_type = (
|
||||||
if impl_config.types.b_group_scale == DataType.void \
|
impl_config.types.a
|
||||||
else impl_config.types.b_group_scale
|
if impl_config.types.b_group_scale == DataType.void
|
||||||
|
else impl_config.types.b_group_scale
|
||||||
|
)
|
||||||
prepack_types.append(
|
prepack_types.append(
|
||||||
PrepackTypeConfig(
|
PrepackTypeConfig(
|
||||||
a=impl_config.types.a,
|
a=impl_config.types.a,
|
||||||
b_num_bits=VLLMDataTypeSize[impl_config.types.b],
|
b_num_bits=VLLMDataTypeSize[impl_config.types.b],
|
||||||
convert=convert_type,
|
convert=convert_type,
|
||||||
accumulator=impl_config.types.accumulator,
|
accumulator=impl_config.types.accumulator,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def prepacked_type_key(prepack_type: PrepackTypeConfig):
|
def prepacked_type_key(prepack_type: PrepackTypeConfig):
|
||||||
# For now, we can just use the first accumulator type seen since
|
# For now, we can just use the first accumulator type seen since
|
||||||
@@ -430,10 +451,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
|||||||
unique_prepack_types.append(prepack_type)
|
unique_prepack_types.append(prepack_type)
|
||||||
prepack_types_seen.add(key)
|
prepack_types_seen.add(key)
|
||||||
|
|
||||||
sources.append((
|
sources.append(
|
||||||
"machete_prepack",
|
(
|
||||||
prepack_dispatch_template.render(types=unique_prepack_types, ),
|
"machete_prepack",
|
||||||
))
|
prepack_dispatch_template.render(
|
||||||
|
types=unique_prepack_types,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Split up impls across files
|
# Split up impls across files
|
||||||
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
||||||
@@ -466,10 +491,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
|||||||
curr_impl_in_file += len(files_impls[-1][-1].schedules)
|
curr_impl_in_file += len(files_impls[-1][-1].schedules)
|
||||||
|
|
||||||
for part, file_impls in enumerate(files_impls):
|
for part, file_impls in enumerate(files_impls):
|
||||||
sources.append((
|
sources.append(
|
||||||
f"machete_mm_impl_part{part+1}",
|
(
|
||||||
mm_impl_template.render(impl_configs=file_impls),
|
f"machete_mm_impl_part{part + 1}",
|
||||||
))
|
mm_impl_template.render(impl_configs=file_impls),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
@@ -514,8 +541,7 @@ def generate():
|
|||||||
# For now we use the same heuristic for all types
|
# For now we use the same heuristic for all types
|
||||||
# Heuristic is currently tuned for H100s
|
# Heuristic is currently tuned for H100s
|
||||||
default_heuristic = [
|
default_heuristic = [
|
||||||
(cond, ScheduleConfig(*tile_config,
|
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
|
||||||
**sch_common_params)) # type: ignore
|
|
||||||
for cond, tile_config in default_tile_heuristic_config.items()
|
for cond, tile_config in default_tile_heuristic_config.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -541,14 +567,18 @@ def generate():
|
|||||||
a_token_scale=DataType.void,
|
a_token_scale=DataType.void,
|
||||||
out=a,
|
out=a,
|
||||||
accumulator=DataType.f32,
|
accumulator=DataType.f32,
|
||||||
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
)
|
||||||
for a in (DataType.f16, DataType.bf16))
|
for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||||
|
for a in (DataType.f16, DataType.bf16)
|
||||||
|
)
|
||||||
|
|
||||||
impl_configs += [
|
impl_configs += [
|
||||||
ImplConfig(x[0], x[1], x[2])
|
ImplConfig(x[0], x[1], x[2])
|
||||||
for x in zip(GPTQ_kernel_type_configs,
|
for x in zip(
|
||||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
GPTQ_kernel_type_configs,
|
||||||
itertools.repeat(default_heuristic))
|
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||||
|
itertools.repeat(default_heuristic),
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
AWQ_kernel_type_configs = list(
|
AWQ_kernel_type_configs = list(
|
||||||
@@ -561,14 +591,18 @@ def generate():
|
|||||||
a_token_scale=DataType.void,
|
a_token_scale=DataType.void,
|
||||||
out=a,
|
out=a,
|
||||||
accumulator=DataType.f32,
|
accumulator=DataType.f32,
|
||||||
) for b in (DataType.u4, DataType.u8)
|
)
|
||||||
for a in (DataType.f16, DataType.bf16))
|
for b in (DataType.u4, DataType.u8)
|
||||||
|
for a in (DataType.f16, DataType.bf16)
|
||||||
|
)
|
||||||
|
|
||||||
impl_configs += [
|
impl_configs += [
|
||||||
ImplConfig(x[0], x[1], x[2])
|
ImplConfig(x[0], x[1], x[2])
|
||||||
for x in zip(AWQ_kernel_type_configs,
|
for x in zip(
|
||||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
AWQ_kernel_type_configs,
|
||||||
itertools.repeat(default_heuristic))
|
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||||
|
itertools.repeat(default_heuristic),
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: Support W4A8 when ready
|
# TODO: Support W4A8 when ready
|
||||||
|
|||||||
@@ -33,8 +33,11 @@ def auto_mock(module, attr, max_mocks=50):
|
|||||||
try:
|
try:
|
||||||
# First treat attr as an attr, then as a submodule
|
# First treat attr as an attr, then as a submodule
|
||||||
with patch("importlib.metadata.version", return_value="0.0.0"):
|
with patch("importlib.metadata.version", return_value="0.0.0"):
|
||||||
return getattr(importlib.import_module(module), attr,
|
return getattr(
|
||||||
importlib.import_module(f"{module}.{attr}"))
|
importlib.import_module(module),
|
||||||
|
attr,
|
||||||
|
importlib.import_module(f"{module}.{attr}"),
|
||||||
|
)
|
||||||
except importlib.metadata.PackageNotFoundError as e:
|
except importlib.metadata.PackageNotFoundError as e:
|
||||||
raise e
|
raise e
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
@@ -42,7 +45,8 @@ def auto_mock(module, attr, max_mocks=50):
|
|||||||
sys.modules[e.name] = PydanticMagicMock()
|
sys.modules[e.name] = PydanticMagicMock()
|
||||||
|
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Failed to import {module}.{attr} after mocking {max_mocks} imports")
|
f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
latency = auto_mock("vllm.benchmarks", "latency")
|
latency = auto_mock("vllm.benchmarks", "latency")
|
||||||
@@ -61,9 +65,7 @@ class MarkdownFormatter(HelpFormatter):
|
|||||||
"""Custom formatter that generates markdown for argument groups."""
|
"""Custom formatter that generates markdown for argument groups."""
|
||||||
|
|
||||||
def __init__(self, prog, starting_heading_level=3):
|
def __init__(self, prog, starting_heading_level=3):
|
||||||
super().__init__(prog,
|
super().__init__(prog, max_help_position=float("inf"), width=float("inf"))
|
||||||
max_help_position=float('inf'),
|
|
||||||
width=float('inf'))
|
|
||||||
self._section_heading_prefix = "#" * starting_heading_level
|
self._section_heading_prefix = "#" * starting_heading_level
|
||||||
self._argument_heading_prefix = "#" * (starting_heading_level + 1)
|
self._argument_heading_prefix = "#" * (starting_heading_level + 1)
|
||||||
self._markdown_output = []
|
self._markdown_output = []
|
||||||
@@ -85,23 +87,19 @@ class MarkdownFormatter(HelpFormatter):
|
|||||||
|
|
||||||
def add_arguments(self, actions):
|
def add_arguments(self, actions):
|
||||||
for action in actions:
|
for action in actions:
|
||||||
if (len(action.option_strings) == 0
|
if len(action.option_strings) == 0 or "--help" in action.option_strings:
|
||||||
or "--help" in action.option_strings):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
option_strings = f'`{"`, `".join(action.option_strings)}`'
|
option_strings = f"`{'`, `'.join(action.option_strings)}`"
|
||||||
heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n"
|
heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n"
|
||||||
self._markdown_output.append(heading_md)
|
self._markdown_output.append(heading_md)
|
||||||
|
|
||||||
if choices := action.choices:
|
if choices := action.choices:
|
||||||
choices = f'`{"`, `".join(str(c) for c in choices)}`'
|
choices = f"`{'`, `'.join(str(c) for c in choices)}`"
|
||||||
self._markdown_output.append(
|
self._markdown_output.append(f"Possible choices: {choices}\n\n")
|
||||||
f"Possible choices: {choices}\n\n")
|
elif (metavar := action.metavar) and isinstance(metavar, (list, tuple)):
|
||||||
elif ((metavar := action.metavar)
|
metavar = f"`{'`, `'.join(str(m) for m in metavar)}`"
|
||||||
and isinstance(metavar, (list, tuple))):
|
self._markdown_output.append(f"Possible choices: {metavar}\n\n")
|
||||||
metavar = f'`{"`, `".join(str(m) for m in metavar)}`'
|
|
||||||
self._markdown_output.append(
|
|
||||||
f"Possible choices: {metavar}\n\n")
|
|
||||||
|
|
||||||
if action.help:
|
if action.help:
|
||||||
self._markdown_output.append(f"{action.help}\n\n")
|
self._markdown_output.append(f"{action.help}\n\n")
|
||||||
@@ -116,7 +114,7 @@ class MarkdownFormatter(HelpFormatter):
|
|||||||
|
|
||||||
def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
|
def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
|
||||||
"""Create a parser for the given class with markdown formatting.
|
"""Create a parser for the given class with markdown formatting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cls: The class to create a parser for
|
cls: The class to create a parser for
|
||||||
**kwargs: Additional keyword arguments to pass to `cls.add_cli_args`.
|
**kwargs: Additional keyword arguments to pass to `cls.add_cli_args`.
|
||||||
@@ -143,24 +141,17 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
|||||||
|
|
||||||
# Create parsers to document
|
# Create parsers to document
|
||||||
parsers = {
|
parsers = {
|
||||||
"engine_args":
|
"engine_args": create_parser(EngineArgs.add_cli_args),
|
||||||
create_parser(EngineArgs.add_cli_args),
|
"async_engine_args": create_parser(
|
||||||
"async_engine_args":
|
AsyncEngineArgs.add_cli_args, async_args_only=True
|
||||||
create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True),
|
),
|
||||||
"serve":
|
"serve": create_parser(cli_args.make_arg_parser),
|
||||||
create_parser(cli_args.make_arg_parser),
|
"chat": create_parser(ChatCommand.add_cli_args),
|
||||||
"chat":
|
"complete": create_parser(CompleteCommand.add_cli_args),
|
||||||
create_parser(ChatCommand.add_cli_args),
|
"bench_latency": create_parser(latency.add_cli_args),
|
||||||
"complete":
|
"bench_throughput": create_parser(throughput.add_cli_args),
|
||||||
create_parser(CompleteCommand.add_cli_args),
|
"bench_serve": create_parser(serve.add_cli_args),
|
||||||
"bench_latency":
|
"run-batch": create_parser(run_batch.make_arg_parser),
|
||||||
create_parser(latency.add_cli_args),
|
|
||||||
"bench_throughput":
|
|
||||||
create_parser(throughput.add_cli_args),
|
|
||||||
"bench_serve":
|
|
||||||
create_parser(serve.add_cli_args),
|
|
||||||
"run-batch":
|
|
||||||
create_parser(run_batch.make_arg_parser),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate documentation for each parser
|
# Generate documentation for each parser
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import regex as re
|
|||||||
logger = logging.getLogger("mkdocs")
|
logger = logging.getLogger("mkdocs")
|
||||||
|
|
||||||
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||||
ROOT_DIR_RELATIVE = '../../../../..'
|
ROOT_DIR_RELATIVE = "../../../../.."
|
||||||
EXAMPLE_DIR = ROOT_DIR / "examples"
|
EXAMPLE_DIR = ROOT_DIR / "examples"
|
||||||
EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples"
|
EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples"
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ def fix_case(text: str) -> str:
|
|||||||
r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16
|
r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16
|
||||||
}
|
}
|
||||||
for pattern, repl in subs.items():
|
for pattern, repl in subs.items():
|
||||||
text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE)
|
text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
@@ -58,7 +58,8 @@ class Example:
|
|||||||
determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file.
|
determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file.
|
||||||
determine_title() -> str: Determines the title of the document.
|
determine_title() -> str: Determines the title of the document.
|
||||||
generate() -> str: Generates the documentation content.
|
generate() -> str: Generates the documentation content.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
path: Path
|
path: Path
|
||||||
category: str = None
|
category: str = None
|
||||||
main_file: Path = field(init=False)
|
main_file: Path = field(init=False)
|
||||||
@@ -84,9 +85,8 @@ class Example:
|
|||||||
Markdown file found in the directory.
|
Markdown file found in the directory.
|
||||||
Raises:
|
Raises:
|
||||||
IndexError: If no Markdown files are found in the directory.
|
IndexError: If no Markdown files are found in the directory.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
return self.path if self.path.is_file() else list(
|
return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop()
|
||||||
self.path.glob("*.md")).pop()
|
|
||||||
|
|
||||||
def determine_other_files(self) -> list[Path]:
|
def determine_other_files(self) -> list[Path]:
|
||||||
"""
|
"""
|
||||||
@@ -98,7 +98,7 @@ class Example:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Path]: A list of Path objects representing the other files in the directory.
|
list[Path]: A list of Path objects representing the other files in the directory.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
if self.path.is_file():
|
if self.path.is_file():
|
||||||
return []
|
return []
|
||||||
is_other_file = lambda file: file.is_file() and file != self.main_file
|
is_other_file = lambda file: file.is_file() and file != self.main_file
|
||||||
@@ -109,25 +109,25 @@ class Example:
|
|||||||
# Specify encoding for building on Windows
|
# Specify encoding for building on Windows
|
||||||
with open(self.main_file, encoding="utf-8") as f:
|
with open(self.main_file, encoding="utf-8") as f:
|
||||||
first_line = f.readline().strip()
|
first_line = f.readline().strip()
|
||||||
match = re.match(r'^#\s+(?P<title>.+)$', first_line)
|
match = re.match(r"^#\s+(?P<title>.+)$", first_line)
|
||||||
if match:
|
if match:
|
||||||
return match.group('title')
|
return match.group("title")
|
||||||
return fix_case(self.path.stem.replace("_", " ").title())
|
return fix_case(self.path.stem.replace("_", " ").title())
|
||||||
|
|
||||||
def fix_relative_links(self, content: str) -> str:
|
def fix_relative_links(self, content: str) -> str:
|
||||||
"""
|
"""
|
||||||
Fix relative links in markdown content by converting them to gh-file
|
Fix relative links in markdown content by converting them to gh-file
|
||||||
format.
|
format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content (str): The markdown content to process
|
content (str): The markdown content to process
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Content with relative links converted to gh-file format
|
str: Content with relative links converted to gh-file format
|
||||||
"""
|
"""
|
||||||
# Regex to match markdown links [text](relative_path)
|
# Regex to match markdown links [text](relative_path)
|
||||||
# This matches links that don't start with http, https, ftp, or #
|
# This matches links that don't start with http, https, ftp, or #
|
||||||
link_pattern = r'\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)'
|
link_pattern = r"\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)"
|
||||||
|
|
||||||
def replace_link(match):
|
def replace_link(match):
|
||||||
link_text = match.group(1)
|
link_text = match.group(1)
|
||||||
@@ -137,7 +137,7 @@ class Example:
|
|||||||
gh_file = (self.main_file.parent / relative_path).resolve()
|
gh_file = (self.main_file.parent / relative_path).resolve()
|
||||||
gh_file = gh_file.relative_to(ROOT_DIR)
|
gh_file = gh_file.relative_to(ROOT_DIR)
|
||||||
|
|
||||||
return f'[{link_text}](gh-file:{gh_file})'
|
return f"[{link_text}](gh-file:{gh_file})"
|
||||||
|
|
||||||
return re.sub(link_pattern, replace_link, content)
|
return re.sub(link_pattern, replace_link, content)
|
||||||
|
|
||||||
@@ -150,9 +150,11 @@ class Example:
|
|||||||
code_fence = "``````"
|
code_fence = "``````"
|
||||||
|
|
||||||
if self.is_code:
|
if self.is_code:
|
||||||
content += (f"{code_fence}{self.main_file.suffix[1:]}\n"
|
content += (
|
||||||
f'--8<-- "{self.main_file}"\n'
|
f"{code_fence}{self.main_file.suffix[1:]}\n"
|
||||||
f"{code_fence}\n")
|
f'--8<-- "{self.main_file}"\n'
|
||||||
|
f"{code_fence}\n"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
with open(self.main_file) as f:
|
with open(self.main_file) as f:
|
||||||
# Skip the title from md snippets as it's been included above
|
# Skip the title from md snippets as it's been included above
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Literal
|
|||||||
|
|
||||||
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||||
# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa
|
# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa
|
||||||
if os.getenv('READTHEDOCS_VERSION_TYPE') == "tag":
|
if os.getenv("READTHEDOCS_VERSION_TYPE") == "tag":
|
||||||
# remove the warning banner if the version is a tagged release
|
# remove the warning banner if the version is a tagged release
|
||||||
mkdocs_dir = Path(__file__).parent.parent
|
mkdocs_dir = Path(__file__).parent.parent
|
||||||
announcement_path = mkdocs_dir / "overrides/main.html"
|
announcement_path = mkdocs_dir / "overrides/main.html"
|
||||||
|
|||||||
@@ -25,8 +25,9 @@ from mkdocs.structure.files import Files
|
|||||||
from mkdocs.structure.pages import Page
|
from mkdocs.structure.pages import Page
|
||||||
|
|
||||||
|
|
||||||
def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
|
def on_page_markdown(
|
||||||
files: Files) -> str:
|
markdown: str, *, page: Page, config: MkDocsConfig, files: Files
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Custom MkDocs plugin hook to rewrite special GitHub reference links
|
Custom MkDocs plugin hook to rewrite special GitHub reference links
|
||||||
in Markdown.
|
in Markdown.
|
||||||
@@ -35,7 +36,7 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
|
|||||||
GitHub shorthand links, such as:
|
GitHub shorthand links, such as:
|
||||||
- `[Link text](gh-issue:123)`
|
- `[Link text](gh-issue:123)`
|
||||||
- `<gh-pr:456>`
|
- `<gh-pr:456>`
|
||||||
|
|
||||||
And rewrites them into fully-qualified GitHub URLs with GitHub icons:
|
And rewrites them into fully-qualified GitHub URLs with GitHub icons:
|
||||||
- `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)`
|
- `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)`
|
||||||
- `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)`
|
- `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)`
|
||||||
@@ -88,21 +89,21 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
|
|||||||
"""
|
"""
|
||||||
Replaces a matched inline-style GitHub shorthand link
|
Replaces a matched inline-style GitHub shorthand link
|
||||||
with a full Markdown link.
|
with a full Markdown link.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
[My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123)
|
[My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123)
|
||||||
"""
|
"""
|
||||||
url = f'{urls[match.group("type")]}/{match.group("path")}'
|
url = f"{urls[match.group('type')]}/{match.group('path')}"
|
||||||
if fragment := match.group("fragment"):
|
if fragment := match.group("fragment"):
|
||||||
url += f"#{fragment}"
|
url += f"#{fragment}"
|
||||||
|
|
||||||
return f'[{gh_icon} {match.group("title")}]({url})'
|
return f"[{gh_icon} {match.group('title')}]({url})"
|
||||||
|
|
||||||
def replace_auto_link(match: re.Match) -> str:
|
def replace_auto_link(match: re.Match) -> str:
|
||||||
"""
|
"""
|
||||||
Replaces a matched autolink-style GitHub shorthand
|
Replaces a matched autolink-style GitHub shorthand
|
||||||
with a full Markdown link.
|
with a full Markdown link.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
<gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)
|
<gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
exclude = [
|
|
||||||
# External file, leaving license intact
|
|
||||||
"examples/other/fp8/quantizer/quantize.py",
|
|
||||||
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
|
||||||
known-first-party = ["vllm"]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
127
pyproject.toml
127
pyproject.toml
@@ -52,27 +52,106 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi
|
|||||||
where = ["."]
|
where = ["."]
|
||||||
include = ["vllm*"]
|
include = ["vllm*"]
|
||||||
|
|
||||||
[tool.yapfignore]
|
|
||||||
ignore_patterns = [
|
|
||||||
".buildkite/**",
|
|
||||||
"benchmarks/**",
|
|
||||||
"build/**",
|
|
||||||
"examples/**",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
# Allow lines to be as long as 80.
|
|
||||||
line-length = 80
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"vllm/third_party/**" = ["ALL"]
|
"vllm/third_party/**" = ["ALL"]
|
||||||
"vllm/version.py" = ["F401"]
|
"vllm/version.py" = ["F401"]
|
||||||
"vllm/_version.py" = ["ALL"]
|
"vllm/_version.py" = ["ALL"]
|
||||||
# Python 3.8 typing - skip V0 code
|
# TEMPORARY! These ignores will be fixed forward
|
||||||
"vllm/attention/**/*.py" = ["UP006", "UP035"]
|
## Line length violations
|
||||||
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
"csrc/cutlass_extensions/vllm_cutlass_library_extension.py" = ["E501"]
|
||||||
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
"tests/compile/piecewise/test_simple.py" = ["E501"]
|
||||||
"vllm/worker/**/*.py" = ["UP006", "UP035"]
|
"tests/compile/piecewise/test_toy_llama.py" = ["E501", "B023"]
|
||||||
|
"tests/entrypoints/conftest.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_audio.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_chat.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_chat_template.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_chat_with_tool_reasoning.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_completion_with_function_calling.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_video.py" = ["E501"]
|
||||||
|
"tests/entrypoints/openai/test_vision.py" = ["E501"]
|
||||||
|
"tests/entrypoints/test_chat_utils.py" = ["E501"]
|
||||||
|
"tests/kernels/moe/modular_kernel_tools/common.py" = ["E501"]
|
||||||
|
"tests/models/language/generation/test_gemma.py" = ["E501"]
|
||||||
|
"tests/models/language/generation/test_mistral.py" = ["E501"]
|
||||||
|
"tests/models/multimodal/generation/test_ultravox.py" = ["E501"]
|
||||||
|
"tests/models/multimodal/generation/test_voxtral.py" = ["E501"]
|
||||||
|
"tests/models/multimodal/generation/vlm_utils/custom_inputs.py" = ["E501"]
|
||||||
|
"tests/tool_use/test_tool_choice_required.py" = ["E501"]
|
||||||
|
"tests/v1/attention/utils.py" = ["E501"]
|
||||||
|
"tests/v1/entrypoints/openai/responses/test_image.py" = ["E501"]
|
||||||
|
"tests/v1/kv_connector/nixl_integration/test_accuracy.py" = ["E501"]
|
||||||
|
"tests/v1/kv_connector/unit/test_offloading_connector.py" = ["E501"]
|
||||||
|
"tests/v1/logits_processors/test_custom_offline.py" = ["E501"]
|
||||||
|
"vllm/attention/ops/pallas_kv_cache_update.py" = ["E501"]
|
||||||
|
"vllm/compilation/collective_fusion.py" = ["E501"]
|
||||||
|
"vllm/compilation/wrapper.py" = ["E501"]
|
||||||
|
"vllm/config/vllm.py" = ["E501"]
|
||||||
|
"vllm/distributed/device_communicators/all2all.py" = ["E501"]
|
||||||
|
"vllm/entrypoints/openai/protocol.py" = ["E501"]
|
||||||
|
"vllm/lora/layers/vocal_parallel_embedding.py" = ["E501"]
|
||||||
|
"vllm/model_executor/model_loader/bitsandbytes_loader.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/bailing_moe.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/hyperclovax_vision.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/llama4_eagle.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/longcat_flash_mtp.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/phi4mm.py" = ["E501"]
|
||||||
|
"vllm/model_executor/models/qwen3_next.py" = ["E501"]
|
||||||
|
"vllm/model_executor/layers/quantization/ptpc_fp8.py" = ["E501"]
|
||||||
|
"vllm/v1/attention/backends/mla/common.py" = ["E501"]
|
||||||
|
"vllm/v1/engine/utils.py" = ["E501"]
|
||||||
|
"vllm/v1/utils.py" = ["E501"]
|
||||||
|
"vllm/v1/worker/gpu_model_runner.py" = ["E501"]
|
||||||
|
## Simplification rules
|
||||||
|
"tests/distributed/test_expert_placement.py" = ["SIM108"]
|
||||||
|
"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"]
|
||||||
|
"tests/kernels/attention/test_flashmla.py" = ["SIM108"]
|
||||||
|
"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"]
|
||||||
|
"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"]
|
||||||
|
"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"]
|
||||||
|
"tests/kernels/test_onednn.py" = ["SIM108"]
|
||||||
|
"tests/kernels/utils.py" = ["SIM108"]
|
||||||
|
"tests/multimodal/test_processing.py" = ["SIM108"]
|
||||||
|
"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"]
|
||||||
|
"vllm/distributed/parallel_state.py" = ["SIM108"]
|
||||||
|
"vllm/entrypoints/chat_utils.py" = ["SIM108"]
|
||||||
|
"vllm/entrypoints/llm.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/layernorm.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"]
|
||||||
|
"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"]
|
||||||
|
"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"]
|
||||||
|
"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"]
|
||||||
|
"vllm/utils/__init__.py" = ["SIM108"]
|
||||||
|
"vllm/v1/sample/ops/bad_words.py" = ["SIM108"]
|
||||||
|
"vllm/v1/sample/rejection_sampler.py" = ["SIM108"]
|
||||||
|
"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"]
|
||||||
|
"vllm/_custom_ops.py" = ["SIM108"]
|
||||||
|
"tools/profiler/print_layerwise_table.py" = ["SIM118"]
|
||||||
|
## Loop variable binding issues
|
||||||
|
"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
|
||||||
|
## Type annotation modernization and other rules
|
||||||
|
"vllm/attention/backends/abstract.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/layer.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/engine/arg_utils.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/engine/metrics.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/engine/metrics_types.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/executor_base.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"]
|
||||||
|
"vllm/executor/ray_utils.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"]
|
||||||
|
"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"]
|
||||||
|
## Type comparison issues
|
||||||
|
"vllm/multimodal/inputs.py" = ["E721"]
|
||||||
|
# End of temporary ignores
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
@@ -87,7 +166,7 @@ select = [
|
|||||||
# flake8-simplify
|
# flake8-simplify
|
||||||
"SIM",
|
"SIM",
|
||||||
# isort
|
# isort
|
||||||
# "I",
|
"I",
|
||||||
# flake8-logging-format
|
# flake8-logging-format
|
||||||
"G",
|
"G",
|
||||||
]
|
]
|
||||||
@@ -104,21 +183,15 @@ ignore = [
|
|||||||
"UP007",
|
"UP007",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
docstring-code-format = true
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
plugins = ['pydantic.mypy']
|
plugins = ['pydantic.mypy']
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
follow_imports = "silent"
|
follow_imports = "silent"
|
||||||
|
|
||||||
[tool.isort]
|
|
||||||
skip_glob = [
|
|
||||||
".buildkite/*",
|
|
||||||
"benchmarks/*",
|
|
||||||
"examples/*",
|
|
||||||
]
|
|
||||||
use_parentheses = true
|
|
||||||
skip_gitignore = true
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
markers = [
|
markers = [
|
||||||
"slow_test",
|
"slow_test",
|
||||||
|
|||||||
255
setup.py
255
setup.py
@@ -34,32 +34,36 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# cannot import envs directly because it depends on vllm,
|
# cannot import envs directly because it depends on vllm,
|
||||||
# which is not installed yet
|
# which is not installed yet
|
||||||
envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py'))
|
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py"))
|
||||||
|
|
||||||
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
|
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
|
||||||
|
|
||||||
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
|
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
|
||||||
logger.warning(
|
logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
|
||||||
"VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
|
|
||||||
VLLM_TARGET_DEVICE = "cpu"
|
VLLM_TARGET_DEVICE = "cpu"
|
||||||
elif not (sys.platform.startswith("linux")
|
elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")):
|
||||||
or sys.platform.startswith("darwin")):
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"vLLM only supports Linux platform (including WSL) and MacOS."
|
"vLLM only supports Linux platform (including WSL) and MacOS."
|
||||||
"Building on %s, "
|
"Building on %s, "
|
||||||
"so vLLM may not be able to run correctly", sys.platform)
|
"so vLLM may not be able to run correctly",
|
||||||
|
sys.platform,
|
||||||
|
)
|
||||||
VLLM_TARGET_DEVICE = "empty"
|
VLLM_TARGET_DEVICE = "empty"
|
||||||
elif (sys.platform.startswith("linux") and torch.version.cuda is None
|
elif (
|
||||||
and os.getenv("VLLM_TARGET_DEVICE") is None
|
sys.platform.startswith("linux")
|
||||||
and torch.version.hip is None):
|
and torch.version.cuda is None
|
||||||
|
and os.getenv("VLLM_TARGET_DEVICE") is None
|
||||||
|
and torch.version.hip is None
|
||||||
|
):
|
||||||
# if cuda or hip is not available and VLLM_TARGET_DEVICE is not set,
|
# if cuda or hip is not available and VLLM_TARGET_DEVICE is not set,
|
||||||
# fallback to cpu
|
# fallback to cpu
|
||||||
VLLM_TARGET_DEVICE = "cpu"
|
VLLM_TARGET_DEVICE = "cpu"
|
||||||
|
|
||||||
|
|
||||||
def is_sccache_available() -> bool:
|
def is_sccache_available() -> bool:
|
||||||
return which("sccache") is not None and \
|
return which("sccache") is not None and not bool(
|
||||||
not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0")))
|
int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_ccache_available() -> bool:
|
def is_ccache_available() -> bool:
|
||||||
@@ -83,8 +87,7 @@ def is_url_available(url: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
class CMakeExtension(Extension):
|
class CMakeExtension(Extension):
|
||||||
|
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
|
||||||
def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
|
|
||||||
super().__init__(name, sources=[], py_limited_api=True, **kwa)
|
super().__init__(name, sources=[], py_limited_api=True, **kwa)
|
||||||
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
|
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
|
||||||
|
|
||||||
@@ -121,8 +124,8 @@ class cmake_build_ext(build_ext):
|
|||||||
if nvcc_threads is not None:
|
if nvcc_threads is not None:
|
||||||
nvcc_threads = int(nvcc_threads)
|
nvcc_threads = int(nvcc_threads)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using NVCC_THREADS=%d as the number of nvcc threads.",
|
"Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads
|
||||||
nvcc_threads)
|
)
|
||||||
else:
|
else:
|
||||||
nvcc_threads = 1
|
nvcc_threads = 1
|
||||||
num_jobs = max(1, num_jobs // nvcc_threads)
|
num_jobs = max(1, num_jobs // nvcc_threads)
|
||||||
@@ -146,36 +149,36 @@ class cmake_build_ext(build_ext):
|
|||||||
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
|
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
|
||||||
|
|
||||||
cmake_args = [
|
cmake_args = [
|
||||||
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
|
"-DCMAKE_BUILD_TYPE={}".format(cfg),
|
||||||
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
|
"-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE),
|
||||||
]
|
]
|
||||||
|
|
||||||
verbose = envs.VERBOSE
|
verbose = envs.VERBOSE
|
||||||
if verbose:
|
if verbose:
|
||||||
cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON']
|
cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
|
||||||
|
|
||||||
if is_sccache_available():
|
if is_sccache_available():
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
'-DCMAKE_C_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_C_COMPILER_LAUNCHER=sccache",
|
||||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
|
||||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
|
||||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=sccache',
|
"-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
|
||||||
]
|
]
|
||||||
elif is_ccache_available():
|
elif is_ccache_available():
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
'-DCMAKE_C_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_C_COMPILER_LAUNCHER=ccache",
|
||||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
|
||||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
|
||||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=ccache',
|
"-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Pass the python executable to cmake so it can find an exact
|
# Pass the python executable to cmake so it can find an exact
|
||||||
# match.
|
# match.
|
||||||
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
|
cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)]
|
||||||
|
|
||||||
# Pass the python path to cmake so it can reuse the build dependencies
|
# Pass the python path to cmake so it can reuse the build dependencies
|
||||||
# on subsequent calls to python.
|
# on subsequent calls to python.
|
||||||
cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))]
|
cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))]
|
||||||
|
|
||||||
# Override the base directory for FetchContent downloads to $ROOT/.deps
|
# Override the base directory for FetchContent downloads to $ROOT/.deps
|
||||||
# This allows sharing dependencies between profiles,
|
# This allows sharing dependencies between profiles,
|
||||||
@@ -183,7 +186,7 @@ class cmake_build_ext(build_ext):
|
|||||||
# To override this, set the FETCHCONTENT_BASE_DIR environment variable.
|
# To override this, set the FETCHCONTENT_BASE_DIR environment variable.
|
||||||
fc_base_dir = os.path.join(ROOT_DIR, ".deps")
|
fc_base_dir = os.path.join(ROOT_DIR, ".deps")
|
||||||
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
|
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
|
||||||
cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)]
|
cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]
|
||||||
|
|
||||||
#
|
#
|
||||||
# Setup parallelism and build tool
|
# Setup parallelism and build tool
|
||||||
@@ -191,35 +194,36 @@ class cmake_build_ext(build_ext):
|
|||||||
num_jobs, nvcc_threads = self.compute_num_jobs()
|
num_jobs, nvcc_threads = self.compute_num_jobs()
|
||||||
|
|
||||||
if nvcc_threads:
|
if nvcc_threads:
|
||||||
cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)]
|
cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]
|
||||||
|
|
||||||
if is_ninja_available():
|
if is_ninja_available():
|
||||||
build_tool = ['-G', 'Ninja']
|
build_tool = ["-G", "Ninja"]
|
||||||
cmake_args += [
|
cmake_args += [
|
||||||
'-DCMAKE_JOB_POOL_COMPILE:STRING=compile',
|
"-DCMAKE_JOB_POOL_COMPILE:STRING=compile",
|
||||||
'-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs),
|
"-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs),
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# Default build tool to whatever cmake picks.
|
# Default build tool to whatever cmake picks.
|
||||||
build_tool = []
|
build_tool = []
|
||||||
# Make sure we use the nvcc from CUDA_HOME
|
# Make sure we use the nvcc from CUDA_HOME
|
||||||
if _is_cuda():
|
if _is_cuda():
|
||||||
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
|
cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
|
||||||
|
|
||||||
other_cmake_args = os.environ.get("CMAKE_ARGS")
|
other_cmake_args = os.environ.get("CMAKE_ARGS")
|
||||||
if other_cmake_args:
|
if other_cmake_args:
|
||||||
cmake_args += other_cmake_args.split()
|
cmake_args += other_cmake_args.split()
|
||||||
|
|
||||||
subprocess.check_call(
|
subprocess.check_call(
|
||||||
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
|
["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
|
||||||
cwd=self.build_temp)
|
cwd=self.build_temp,
|
||||||
|
)
|
||||||
|
|
||||||
def build_extensions(self) -> None:
|
def build_extensions(self) -> None:
|
||||||
# Ensure that CMake is present and working
|
# Ensure that CMake is present and working
|
||||||
try:
|
try:
|
||||||
subprocess.check_output(['cmake', '--version'])
|
subprocess.check_output(["cmake", "--version"])
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise RuntimeError('Cannot find CMake executable') from e
|
raise RuntimeError("Cannot find CMake executable") from e
|
||||||
|
|
||||||
# Create build directory if it does not exist.
|
# Create build directory if it does not exist.
|
||||||
if not os.path.exists(self.build_temp):
|
if not os.path.exists(self.build_temp):
|
||||||
@@ -258,13 +262,18 @@ class cmake_build_ext(build_ext):
|
|||||||
# CMake appends the extension prefix to the install path,
|
# CMake appends the extension prefix to the install path,
|
||||||
# and outdir already contains that prefix, so we need to remove it.
|
# and outdir already contains that prefix, so we need to remove it.
|
||||||
prefix = outdir
|
prefix = outdir
|
||||||
for _ in range(ext.name.count('.')):
|
for _ in range(ext.name.count(".")):
|
||||||
prefix = prefix.parent
|
prefix = prefix.parent
|
||||||
|
|
||||||
# prefix here should actually be the same for all components
|
# prefix here should actually be the same for all components
|
||||||
install_args = [
|
install_args = [
|
||||||
"cmake", "--install", ".", "--prefix", prefix, "--component",
|
"cmake",
|
||||||
target_name(ext.name)
|
"--install",
|
||||||
|
".",
|
||||||
|
"--prefix",
|
||||||
|
prefix,
|
||||||
|
"--component",
|
||||||
|
target_name(ext.name),
|
||||||
]
|
]
|
||||||
subprocess.check_call(install_args, cwd=self.build_temp)
|
subprocess.check_call(install_args, cwd=self.build_temp)
|
||||||
|
|
||||||
@@ -275,12 +284,15 @@ class cmake_build_ext(build_ext):
|
|||||||
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
|
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
|
||||||
# directory so that they can be included in the editable build
|
# directory so that they can be included in the editable build
|
||||||
import glob
|
import glob
|
||||||
files = glob.glob(os.path.join(self.build_lib, "vllm",
|
|
||||||
"vllm_flash_attn", "**", "*.py"),
|
files = glob.glob(
|
||||||
recursive=True)
|
os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"),
|
||||||
|
recursive=True,
|
||||||
|
)
|
||||||
for file in files:
|
for file in files:
|
||||||
dst_file = os.path.join("vllm/vllm_flash_attn",
|
dst_file = os.path.join(
|
||||||
file.split("vllm/vllm_flash_attn/")[-1])
|
"vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1]
|
||||||
|
)
|
||||||
print(f"Copying {file} to {dst_file}")
|
print(f"Copying {file} to {dst_file}")
|
||||||
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
||||||
self.copy_file(file, dst_file)
|
self.copy_file(file, dst_file)
|
||||||
@@ -290,8 +302,7 @@ class precompiled_build_ext(build_ext):
|
|||||||
"""Disables extension building when using precompiled binaries."""
|
"""Disables extension building when using precompiled binaries."""
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
assert _is_cuda(
|
assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
||||||
), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
|
||||||
|
|
||||||
def build_extensions(self) -> None:
|
def build_extensions(self) -> None:
|
||||||
print("Skipping build_ext: using precompiled extensions.")
|
print("Skipping build_ext: using precompiled extensions.")
|
||||||
@@ -312,9 +323,9 @@ class precompiled_wheel_utils:
|
|||||||
wheel_filename = wheel_url_or_path.split("/")[-1]
|
wheel_filename = wheel_url_or_path.split("/")[-1]
|
||||||
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
||||||
wheel_path = os.path.join(temp_dir, wheel_filename)
|
wheel_path = os.path.join(temp_dir, wheel_filename)
|
||||||
print(f"Downloading wheel from {wheel_url_or_path} "
|
print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}")
|
||||||
f"to {wheel_path}")
|
|
||||||
from urllib.request import urlretrieve
|
from urllib.request import urlretrieve
|
||||||
|
|
||||||
urlretrieve(wheel_url_or_path, filename=wheel_path)
|
urlretrieve(wheel_url_or_path, filename=wheel_path)
|
||||||
else:
|
else:
|
||||||
wheel_path = wheel_url_or_path
|
wheel_path = wheel_url_or_path
|
||||||
@@ -335,25 +346,29 @@ class precompiled_wheel_utils:
|
|||||||
]
|
]
|
||||||
|
|
||||||
compiled_regex = re.compile(
|
compiled_regex = re.compile(
|
||||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
||||||
|
)
|
||||||
file_members = list(
|
file_members = list(
|
||||||
filter(lambda x: x.filename in files_to_copy,
|
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
|
||||||
wheel.filelist))
|
)
|
||||||
file_members += list(
|
file_members += list(
|
||||||
filter(lambda x: compiled_regex.match(x.filename),
|
filter(lambda x: compiled_regex.match(x.filename), wheel.filelist)
|
||||||
wheel.filelist))
|
)
|
||||||
|
|
||||||
for file in file_members:
|
for file in file_members:
|
||||||
print(f"[extract] {file.filename}")
|
print(f"[extract] {file.filename}")
|
||||||
target_path = os.path.join(".", file.filename)
|
target_path = os.path.join(".", file.filename)
|
||||||
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
||||||
with wheel.open(file.filename) as src, open(
|
with (
|
||||||
target_path, "wb") as dst:
|
wheel.open(file.filename) as src,
|
||||||
|
open(target_path, "wb") as dst,
|
||||||
|
):
|
||||||
shutil.copyfileobj(src, dst)
|
shutil.copyfileobj(src, dst)
|
||||||
|
|
||||||
pkg = os.path.dirname(file.filename).replace("/", ".")
|
pkg = os.path.dirname(file.filename).replace("/", ".")
|
||||||
package_data_patch.setdefault(pkg, []).append(
|
package_data_patch.setdefault(pkg, []).append(
|
||||||
os.path.basename(file.filename))
|
os.path.basename(file.filename)
|
||||||
|
)
|
||||||
|
|
||||||
return package_data_patch
|
return package_data_patch
|
||||||
finally:
|
finally:
|
||||||
@@ -369,10 +384,13 @@ class precompiled_wheel_utils:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the latest commit hash of the upstream main branch.
|
# Get the latest commit hash of the upstream main branch.
|
||||||
resp_json = subprocess.check_output([
|
resp_json = subprocess.check_output(
|
||||||
"curl", "-s",
|
[
|
||||||
"https://api.github.com/repos/vllm-project/vllm/commits/main"
|
"curl",
|
||||||
]).decode("utf-8")
|
"-s",
|
||||||
|
"https://api.github.com/repos/vllm-project/vllm/commits/main",
|
||||||
|
]
|
||||||
|
).decode("utf-8")
|
||||||
upstream_main_commit = json.loads(resp_json)["sha"]
|
upstream_main_commit = json.loads(resp_json)["sha"]
|
||||||
|
|
||||||
# In Docker build context, .git may be immutable or missing.
|
# In Docker build context, .git may be immutable or missing.
|
||||||
@@ -382,25 +400,32 @@ class precompiled_wheel_utils:
|
|||||||
# Check if the upstream_main_commit exists in the local repo
|
# Check if the upstream_main_commit exists in the local repo
|
||||||
try:
|
try:
|
||||||
subprocess.check_output(
|
subprocess.check_output(
|
||||||
["git", "cat-file", "-e", f"{upstream_main_commit}"])
|
["git", "cat-file", "-e", f"{upstream_main_commit}"]
|
||||||
|
)
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
# If not present, fetch it from the remote repository.
|
# If not present, fetch it from the remote repository.
|
||||||
# Note that this does not update any local branches,
|
# Note that this does not update any local branches,
|
||||||
# but ensures that this commit ref and its history are
|
# but ensures that this commit ref and its history are
|
||||||
# available in our local repo.
|
# available in our local repo.
|
||||||
subprocess.check_call([
|
subprocess.check_call(
|
||||||
"git", "fetch", "https://github.com/vllm-project/vllm",
|
["git", "fetch", "https://github.com/vllm-project/vllm", "main"]
|
||||||
"main"
|
)
|
||||||
])
|
|
||||||
|
|
||||||
# Then get the commit hash of the current branch that is the same as
|
# Then get the commit hash of the current branch that is the same as
|
||||||
# the upstream main commit.
|
# the upstream main commit.
|
||||||
current_branch = subprocess.check_output(
|
current_branch = (
|
||||||
["git", "branch", "--show-current"]).decode("utf-8").strip()
|
subprocess.check_output(["git", "branch", "--show-current"])
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
base_commit = subprocess.check_output([
|
base_commit = (
|
||||||
"git", "merge-base", f"{upstream_main_commit}", current_branch
|
subprocess.check_output(
|
||||||
]).decode("utf-8").strip()
|
["git", "merge-base", f"{upstream_main_commit}", current_branch]
|
||||||
|
)
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
return base_commit
|
return base_commit
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise ValueError(err) from None
|
raise ValueError(err) from None
|
||||||
@@ -408,7 +433,9 @@ class precompiled_wheel_utils:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to get the base commit in the main branch. "
|
"Failed to get the base commit in the main branch. "
|
||||||
"Using the nightly wheel. The libraries in this "
|
"Using the nightly wheel. The libraries in this "
|
||||||
"wheel may not be compatible with your dev branch: %s", err)
|
"wheel may not be compatible with your dev branch: %s",
|
||||||
|
err,
|
||||||
|
)
|
||||||
return "nightly"
|
return "nightly"
|
||||||
|
|
||||||
|
|
||||||
@@ -418,12 +445,13 @@ def _no_device() -> bool:
|
|||||||
|
|
||||||
def _is_cuda() -> bool:
|
def _is_cuda() -> bool:
|
||||||
has_cuda = torch.version.cuda is not None
|
has_cuda = torch.version.cuda is not None
|
||||||
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu())
|
return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()
|
||||||
|
|
||||||
|
|
||||||
def _is_hip() -> bool:
|
def _is_hip() -> bool:
|
||||||
return (VLLM_TARGET_DEVICE == "cuda"
|
return (
|
||||||
or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None
|
VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm"
|
||||||
|
) and torch.version.hip is not None
|
||||||
|
|
||||||
|
|
||||||
def _is_tpu() -> bool:
|
def _is_tpu() -> bool:
|
||||||
@@ -462,8 +490,12 @@ def get_rocm_version():
|
|||||||
minor = ctypes.c_uint32()
|
minor = ctypes.c_uint32()
|
||||||
patch = ctypes.c_uint32()
|
patch = ctypes.c_uint32()
|
||||||
|
|
||||||
if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor),
|
if (
|
||||||
ctypes.byref(patch)) == 0):
|
get_rocm_core_version(
|
||||||
|
ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)
|
||||||
|
)
|
||||||
|
== 0
|
||||||
|
):
|
||||||
return f"{major.value}.{minor.value}.{patch.value}"
|
return f"{major.value}.{minor.value}.{patch.value}"
|
||||||
return None
|
return None
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -476,8 +508,9 @@ def get_nvcc_cuda_version() -> Version:
|
|||||||
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
||||||
"""
|
"""
|
||||||
assert CUDA_HOME is not None, "CUDA_HOME is not set"
|
assert CUDA_HOME is not None, "CUDA_HOME is not set"
|
||||||
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
|
nvcc_output = subprocess.check_output(
|
||||||
universal_newlines=True)
|
[CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True
|
||||||
|
)
|
||||||
output = nvcc_output.split()
|
output = nvcc_output.split()
|
||||||
release_idx = output.index("release") + 1
|
release_idx = output.index("release") + 1
|
||||||
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
||||||
@@ -489,14 +522,20 @@ def get_gaudi_sw_version():
|
|||||||
Returns the driver version.
|
Returns the driver version.
|
||||||
"""
|
"""
|
||||||
# Enable console printing for `hl-smi` check
|
# Enable console printing for `hl-smi` check
|
||||||
output = subprocess.run("hl-smi",
|
output = subprocess.run(
|
||||||
shell=True,
|
"hl-smi",
|
||||||
text=True,
|
shell=True,
|
||||||
capture_output=True,
|
text=True,
|
||||||
env={"ENABLE_CONSOLE": "true"})
|
capture_output=True,
|
||||||
|
env={"ENABLE_CONSOLE": "true"},
|
||||||
|
)
|
||||||
if output.returncode == 0 and output.stdout:
|
if output.returncode == 0 and output.stdout:
|
||||||
return output.stdout.split("\n")[2].replace(
|
return (
|
||||||
" ", "").split(":")[1][:-1].split("-")[0]
|
output.stdout.split("\n")[2]
|
||||||
|
.replace(" ", "")
|
||||||
|
.split(":")[1][:-1]
|
||||||
|
.split("-")[0]
|
||||||
|
)
|
||||||
return "0.0.0" # when hl-smi is not available
|
return "0.0.0" # when hl-smi is not available
|
||||||
|
|
||||||
|
|
||||||
@@ -546,8 +585,11 @@ def get_requirements() -> list[str]:
|
|||||||
for line in requirements:
|
for line in requirements:
|
||||||
if line.startswith("-r "):
|
if line.startswith("-r "):
|
||||||
resolved_requirements += _read_requirements(line.split()[1])
|
resolved_requirements += _read_requirements(line.split()[1])
|
||||||
elif not line.startswith("--") and not line.startswith(
|
elif (
|
||||||
"#") and line.strip() != "":
|
not line.startswith("--")
|
||||||
|
and not line.startswith("#")
|
||||||
|
and line.strip() != ""
|
||||||
|
):
|
||||||
resolved_requirements.append(line)
|
resolved_requirements.append(line)
|
||||||
return resolved_requirements
|
return resolved_requirements
|
||||||
|
|
||||||
@@ -558,7 +600,7 @@ def get_requirements() -> list[str]:
|
|||||||
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
||||||
modified_requirements = []
|
modified_requirements = []
|
||||||
for req in requirements:
|
for req in requirements:
|
||||||
if ("vllm-flash-attn" in req and cuda_major != "12"):
|
if "vllm-flash-attn" in req and cuda_major != "12":
|
||||||
# vllm-flash-attn is built only for CUDA 12.x.
|
# vllm-flash-attn is built only for CUDA 12.x.
|
||||||
# Skip for other versions.
|
# Skip for other versions.
|
||||||
continue
|
continue
|
||||||
@@ -573,8 +615,7 @@ def get_requirements() -> list[str]:
|
|||||||
elif _is_xpu():
|
elif _is_xpu():
|
||||||
requirements = _read_requirements("xpu.txt")
|
requirements = _read_requirements("xpu.txt")
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.")
|
||||||
"Unsupported platform, please use CUDA, ROCm, or CPU.")
|
|
||||||
return requirements
|
return requirements
|
||||||
|
|
||||||
|
|
||||||
@@ -590,14 +631,13 @@ if _is_cuda():
|
|||||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
||||||
# FA3 requires CUDA 12.3 or later
|
# FA3 requires CUDA 12.3 or later
|
||||||
ext_modules.append(
|
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
|
||||||
# Optional since this doesn't get built (produce an .so file) when
|
# Optional since this doesn't get built (produce an .so file) when
|
||||||
# not targeting a hopper system
|
# not targeting a hopper system
|
||||||
|
ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
CMakeExtension(name="vllm._flashmla_C", optional=True))
|
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
||||||
ext_modules.append(
|
)
|
||||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
|
|
||||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||||
|
|
||||||
if _build_custom_ops():
|
if _build_custom_ops():
|
||||||
@@ -619,6 +659,7 @@ if envs.VLLM_USE_PRECOMPILED:
|
|||||||
wheel_url = wheel_location
|
wheel_url = wheel_location
|
||||||
else:
|
else:
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
arch = platform.machine()
|
arch = platform.machine()
|
||||||
if arch == "x86_64":
|
if arch == "x86_64":
|
||||||
wheel_tag = "manylinux1_x86_64"
|
wheel_tag = "manylinux1_x86_64"
|
||||||
@@ -628,8 +669,11 @@ if envs.VLLM_USE_PRECOMPILED:
|
|||||||
raise ValueError(f"Unsupported architecture: {arch}")
|
raise ValueError(f"Unsupported architecture: {arch}")
|
||||||
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||||
nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
nightly_wheel_url = (
|
||||||
|
f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||||
|
)
|
||||||
from urllib.request import urlopen
|
from urllib.request import urlopen
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with urlopen(wheel_url) as resp:
|
with urlopen(wheel_url) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
@@ -638,8 +682,7 @@ if envs.VLLM_USE_PRECOMPILED:
|
|||||||
print(f"[warn] Falling back to nightly wheel: {e}")
|
print(f"[warn] Falling back to nightly wheel: {e}")
|
||||||
wheel_url = nightly_wheel_url
|
wheel_url = nightly_wheel_url
|
||||||
|
|
||||||
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
|
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url)
|
||||||
wheel_url)
|
|
||||||
for pkg, files in patch.items():
|
for pkg, files in patch.items():
|
||||||
package_data.setdefault(pkg, []).extend(files)
|
package_data.setdefault(pkg, []).extend(files)
|
||||||
|
|
||||||
@@ -650,8 +693,9 @@ if not ext_modules:
|
|||||||
cmdclass = {}
|
cmdclass = {}
|
||||||
else:
|
else:
|
||||||
cmdclass = {
|
cmdclass = {
|
||||||
"build_ext":
|
"build_ext": precompiled_build_ext
|
||||||
precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
|
if envs.VLLM_USE_PRECOMPILED
|
||||||
|
else cmake_build_ext
|
||||||
}
|
}
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
@@ -664,8 +708,11 @@ setup(
|
|||||||
"tensorizer": ["tensorizer==2.10.1"],
|
"tensorizer": ["tensorizer==2.10.1"],
|
||||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||||
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
||||||
"audio": ["librosa", "soundfile",
|
"audio": [
|
||||||
"mistral_common[audio]"], # Required for audio processing
|
"librosa",
|
||||||
|
"soundfile",
|
||||||
|
"mistral_common[audio]",
|
||||||
|
], # Required for audio processing
|
||||||
"video": [], # Kept for backwards compatibility
|
"video": [], # Kept for backwards compatibility
|
||||||
# FlashInfer should be updated together with the Dockerfile
|
# FlashInfer should be updated together with the Dockerfile
|
||||||
"flashinfer": ["flashinfer-python==0.3.1"],
|
"flashinfer": ["flashinfer-python==0.3.1"],
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import weakref
|
import weakref
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
@@ -37,16 +38,21 @@ def test_vllm_gc_ed():
|
|||||||
|
|
||||||
|
|
||||||
def _fix_prompt_embed_outputs(
|
def _fix_prompt_embed_outputs(
|
||||||
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
|
vllm_outputs: list[tuple[list[int], str]],
|
||||||
example_prompts: list[str]) -> list[tuple[list[int], str]]:
|
hf_model: HfRunner,
|
||||||
|
example_prompts: list[str],
|
||||||
|
) -> list[tuple[list[int], str]]:
|
||||||
fixed_vllm_outputs = []
|
fixed_vllm_outputs = []
|
||||||
for vllm_output, hf_input, prompt in zip(
|
for vllm_output, hf_input, prompt in zip(
|
||||||
vllm_outputs, hf_model.get_inputs(example_prompts),
|
vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts
|
||||||
example_prompts):
|
):
|
||||||
hf_input_ids = hf_input["input_ids"].tolist()[0]
|
hf_input_ids = hf_input["input_ids"].tolist()[0]
|
||||||
fixed_vllm_outputs.append(
|
fixed_vllm_outputs.append(
|
||||||
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
|
(
|
||||||
prompt + vllm_output[1]))
|
hf_input_ids + vllm_output[0][len(hf_input_ids) :],
|
||||||
|
prompt + vllm_output[1],
|
||||||
|
)
|
||||||
|
)
|
||||||
return fixed_vllm_outputs
|
return fixed_vllm_outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -69,8 +75,7 @@ def test_models(
|
|||||||
enable_prompt_embeds: bool,
|
enable_prompt_embeds: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
||||||
pytest.skip(
|
pytest.skip(f"{backend} does not support gemma2 with full context length.")
|
||||||
f"{backend} does not support gemma2 with full context length.")
|
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
@@ -78,34 +83,35 @@ def test_models(
|
|||||||
# 5042 tokens for gemma2
|
# 5042 tokens for gemma2
|
||||||
# gemma2 has alternating sliding window size of 4096
|
# gemma2 has alternating sliding window size of 4096
|
||||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||||
prompt = "The following numbers of the sequence " + ", ".join(
|
prompt = (
|
||||||
str(i) for i in range(1024)) + " are:"
|
"The following numbers of the sequence "
|
||||||
|
+ ", ".join(str(i) for i in range(1024))
|
||||||
|
+ " are:"
|
||||||
|
)
|
||||||
example_prompts = [prompt]
|
example_prompts = [prompt]
|
||||||
|
|
||||||
with hf_runner(model) as hf_model:
|
with hf_runner(model) as hf_model:
|
||||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt_embeds = hf_model.get_prompt_embeddings(
|
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||||
example_prompts)
|
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enforce_eager=enforce_eager,
|
enforce_eager=enforce_eager,
|
||||||
enable_prompt_embeds=enable_prompt_embeds,
|
enable_prompt_embeds=enable_prompt_embeds,
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.7,
|
||||||
async_scheduling=async_scheduling,
|
async_scheduling=async_scheduling,
|
||||||
distributed_executor_backend=model_executor,
|
distributed_executor_backend=model_executor,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||||
prompt_embeds, max_tokens)
|
|
||||||
vllm_outputs = _fix_prompt_embed_outputs(
|
vllm_outputs = _fix_prompt_embed_outputs(
|
||||||
vllm_outputs, hf_model, example_prompts)
|
vllm_outputs, hf_model, example_prompts
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
example_prompts, max_tokens)
|
|
||||||
|
|
||||||
check_outputs_equal(
|
check_outputs_equal(
|
||||||
outputs_0_lst=hf_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
@@ -117,21 +123,18 @@ def test_models(
|
|||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, distributed_executor_backend, attention_backend, "
|
"model, distributed_executor_backend, attention_backend, test_suite, extra_env",
|
||||||
"test_suite, extra_env", [
|
[
|
||||||
("distilbert/distilgpt2", "ray", "", "L4", {}),
|
("distilbert/distilgpt2", "ray", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "mp", "", "L4", {}),
|
("distilbert/distilgpt2", "mp", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "ray", "", "L4", {
|
("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||||
"VLLM_SLEEP_WHEN_IDLE": "1"
|
("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||||
}),
|
|
||||||
("distilbert/distilgpt2", "mp", "", "L4", {
|
|
||||||
"VLLM_SLEEP_WHEN_IDLE": "1"
|
|
||||||
}),
|
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
||||||
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||||
def test_models_distributed(
|
def test_models_distributed(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
@@ -149,11 +152,14 @@ def test_models_distributed(
|
|||||||
pytest.skip(f"Skip test for {test_suite}")
|
pytest.skip(f"Skip test for {test_suite}")
|
||||||
|
|
||||||
with monkeypatch.context() as monkeypatch_context:
|
with monkeypatch.context() as monkeypatch_context:
|
||||||
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
|
if (
|
||||||
|
model == "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
and distributed_executor_backend == "ray"
|
||||||
|
and attention_backend == ""
|
||||||
|
and test_suite == "L4"
|
||||||
|
): # noqa
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
pytest.skip(
|
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
|
||||||
"enable_prompt_embeds does not work with ray compiled dag."
|
|
||||||
)
|
|
||||||
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
|
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
|
||||||
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
|
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
|
||||||
|
|
||||||
@@ -175,30 +181,26 @@ def test_models_distributed(
|
|||||||
# will hurt multiprocessing backend with fork method
|
# will hurt multiprocessing backend with fork method
|
||||||
# (the default method).
|
# (the default method).
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model,
|
model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
distributed_executor_backend=distributed_executor_backend,
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
enable_prompt_embeds=enable_prompt_embeds,
|
enable_prompt_embeds=enable_prompt_embeds,
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.7,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
with hf_runner(model, dtype=dtype) as hf_model:
|
with hf_runner(model, dtype=dtype) as hf_model:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt_embeds = hf_model.get_prompt_embeddings(
|
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||||
example_prompts)
|
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
|
||||||
prompt_embeds, max_tokens)
|
|
||||||
vllm_outputs = _fix_prompt_embed_outputs(
|
vllm_outputs = _fix_prompt_embed_outputs(
|
||||||
vllm_outputs, hf_model, example_prompts)
|
vllm_outputs, hf_model, example_prompts
|
||||||
hf_outputs = hf_model.generate_greedy(
|
)
|
||||||
example_prompts, max_tokens)
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
else:
|
else:
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
example_prompts, max_tokens)
|
|
||||||
with hf_runner(model, dtype=dtype) as hf_model:
|
with hf_runner(model, dtype=dtype) as hf_model:
|
||||||
hf_outputs = hf_model.generate_greedy(
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
example_prompts, max_tokens)
|
|
||||||
|
|
||||||
check_outputs_equal(
|
check_outputs_equal(
|
||||||
outputs_0_lst=hf_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
@@ -209,27 +211,23 @@ def test_models_distributed(
|
|||||||
|
|
||||||
|
|
||||||
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
|
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
|
||||||
|
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
|
|
||||||
if not VLLM_USE_V1:
|
if not VLLM_USE_V1:
|
||||||
pytest.skip("Skipping V0 test, dump input not supported")
|
pytest.skip("Skipping V0 test, dump input not supported")
|
||||||
|
|
||||||
# Needed to mock an error in the same process
|
# Needed to mock an error in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
|
with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model:
|
||||||
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
|
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
|
||||||
v1_test_failed_model_execution(vllm_model)
|
v1_test_failed_model_execution(vllm_model)
|
||||||
|
|
||||||
|
|
||||||
def v1_test_failed_model_execution(vllm_model):
|
def v1_test_failed_model_execution(vllm_model):
|
||||||
|
|
||||||
engine = vllm_model.llm.llm_engine
|
engine = vllm_model.llm.llm_engine
|
||||||
mocked_execute_model = Mock(
|
mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
|
||||||
side_effect=RuntimeError("Mocked Critical Error"))
|
engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model
|
||||||
engine.engine_core.engine_core.model_executor.execute_model =\
|
|
||||||
mocked_execute_model
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
prompts = [
|
prompts = [
|
||||||
|
|||||||
@@ -5,5 +5,6 @@ from ..utils import compare_two_settings
|
|||||||
|
|
||||||
|
|
||||||
def test_cpu_offload():
|
def test_cpu_offload():
|
||||||
compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
|
compare_two_settings(
|
||||||
["--cpu-offload-gb", "1"])
|
"meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"]
|
||||||
|
)
|
||||||
|
|||||||
@@ -23,13 +23,13 @@ def test_python_error():
|
|||||||
tensors = []
|
tensors = []
|
||||||
with allocator.use_memory_pool():
|
with allocator.use_memory_pool():
|
||||||
# allocate 70% of the total memory
|
# allocate 70% of the total memory
|
||||||
x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
|
x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||||
tensors.append(x)
|
tensors.append(x)
|
||||||
# release the memory
|
# release the memory
|
||||||
allocator.sleep()
|
allocator.sleep()
|
||||||
|
|
||||||
# allocate more memory than the total memory
|
# allocate more memory than the total memory
|
||||||
y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
|
y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||||
tensors.append(y)
|
tensors.append(y)
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
# when the allocator is woken up, it should raise an error
|
# when the allocator is woken up, it should raise an error
|
||||||
@@ -41,17 +41,17 @@ def test_python_error():
|
|||||||
def test_basic_cumem():
|
def test_basic_cumem():
|
||||||
# some tensors from default memory pool
|
# some tensors from default memory pool
|
||||||
shape = (1024, 1024)
|
shape = (1024, 1024)
|
||||||
x = torch.empty(shape, device='cuda')
|
x = torch.empty(shape, device="cuda")
|
||||||
x.zero_()
|
x.zero_()
|
||||||
|
|
||||||
# some tensors from custom memory pool
|
# some tensors from custom memory pool
|
||||||
allocator = CuMemAllocator.get_instance()
|
allocator = CuMemAllocator.get_instance()
|
||||||
with allocator.use_memory_pool():
|
with allocator.use_memory_pool():
|
||||||
# custom memory pool
|
# custom memory pool
|
||||||
y = torch.empty(shape, device='cuda')
|
y = torch.empty(shape, device="cuda")
|
||||||
y.zero_()
|
y.zero_()
|
||||||
y += 1
|
y += 1
|
||||||
z = torch.empty(shape, device='cuda')
|
z = torch.empty(shape, device="cuda")
|
||||||
z.zero_()
|
z.zero_()
|
||||||
z += 2
|
z += 2
|
||||||
|
|
||||||
@@ -74,16 +74,16 @@ def test_basic_cumem():
|
|||||||
def test_cumem_with_cudagraph():
|
def test_cumem_with_cudagraph():
|
||||||
allocator = CuMemAllocator.get_instance()
|
allocator = CuMemAllocator.get_instance()
|
||||||
with allocator.use_memory_pool():
|
with allocator.use_memory_pool():
|
||||||
weight = torch.eye(1024, device='cuda')
|
weight = torch.eye(1024, device="cuda")
|
||||||
with allocator.use_memory_pool(tag="discard"):
|
with allocator.use_memory_pool(tag="discard"):
|
||||||
cache = torch.empty(1024, 1024, device='cuda')
|
cache = torch.empty(1024, 1024, device="cuda")
|
||||||
|
|
||||||
def model(x):
|
def model(x):
|
||||||
out = x @ weight
|
out = x @ weight
|
||||||
cache[:out.size(0)].copy_(out)
|
cache[: out.size(0)].copy_(out)
|
||||||
return out + 1
|
return out + 1
|
||||||
|
|
||||||
x = torch.empty(128, 1024, device='cuda')
|
x = torch.empty(128, 1024, device="cuda")
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
model(x)
|
model(x)
|
||||||
@@ -109,7 +109,7 @@ def test_cumem_with_cudagraph():
|
|||||||
model_graph.replay()
|
model_graph.replay()
|
||||||
|
|
||||||
# cache content is as expected
|
# cache content is as expected
|
||||||
assert torch.allclose(x, cache[:x.size(0)])
|
assert torch.allclose(x, cache[: x.size(0)])
|
||||||
|
|
||||||
# output content is as expected
|
# output content is as expected
|
||||||
assert torch.allclose(y, x + 1)
|
assert torch.allclose(y, x + 1)
|
||||||
@@ -123,7 +123,8 @@ def test_cumem_with_cudagraph():
|
|||||||
("meta-llama/Llama-3.2-1B", True),
|
("meta-llama/Llama-3.2-1B", True),
|
||||||
# sleep mode with pytorch checkpoint
|
# sleep mode with pytorch checkpoint
|
||||||
("facebook/opt-125m", True),
|
("facebook/opt-125m", True),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
|
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
assert use_v1
|
assert use_v1
|
||||||
|
|||||||
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
|||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_latency():
|
def test_bench_latency():
|
||||||
command = [
|
command = [
|
||||||
"vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32",
|
"vllm",
|
||||||
"--output-len", "1", "--enforce-eager", "--load-format", "dummy"
|
"bench",
|
||||||
|
"latency",
|
||||||
|
"--model",
|
||||||
|
MODEL_NAME,
|
||||||
|
"--input-len",
|
||||||
|
"32",
|
||||||
|
"--output-len",
|
||||||
|
"1",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--load-format",
|
||||||
|
"dummy",
|
||||||
]
|
]
|
||||||
result = subprocess.run(command, capture_output=True, text=True)
|
result = subprocess.run(command, capture_output=True, text=True)
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
|
|||||||
@@ -7,8 +7,11 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
|
from vllm.benchmarks.datasets import (
|
||||||
SampleRequest)
|
RandomDataset,
|
||||||
|
RandomMultiModalDataset,
|
||||||
|
SampleRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@@ -27,11 +30,9 @@ class Params(NamedTuple):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def random_dataset_params() -> Params:
|
def random_dataset_params() -> Params:
|
||||||
return Params(num_requests=16,
|
return Params(
|
||||||
prefix_len=7,
|
num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20
|
||||||
range_ratio=0.3,
|
)
|
||||||
input_len=50,
|
|
||||||
output_len=20)
|
|
||||||
|
|
||||||
|
|
||||||
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
||||||
@@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
|||||||
return (req.prompt, req.prompt_len, req.expected_output_len)
|
return (req.prompt, req.prompt_len, req.expected_output_len)
|
||||||
|
|
||||||
|
|
||||||
def _collect_samples(dataset: RandomDataset,
|
def _collect_samples(
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
dataset: RandomDataset,
|
||||||
num_requests: int = 16,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
prefix_len: int = 7,
|
num_requests: int = 16,
|
||||||
range_ratio: float = 0.3,
|
prefix_len: int = 7,
|
||||||
input_len: int = 50,
|
range_ratio: float = 0.3,
|
||||||
output_len: int = 20) -> list[tuple[str, int, int]]:
|
input_len: int = 50,
|
||||||
|
output_len: int = 20,
|
||||||
|
) -> list[tuple[str, int, int]]:
|
||||||
samples = dataset.sample(
|
samples = dataset.sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=num_requests,
|
num_requests=num_requests,
|
||||||
@@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
|
|||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_dataset_same_seed(
|
def test_random_dataset_same_seed(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
|
||||||
random_dataset_params: Params) -> None:
|
) -> None:
|
||||||
"""Same seed should yield identical outputs, even if global RNGs change.
|
"""Same seed should yield identical outputs, even if global RNGs change.
|
||||||
|
|
||||||
This guards against accidental reliance on Python's random or np.random
|
This guards against accidental reliance on Python's random or np.random
|
||||||
@@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
|
|||||||
common_seed = 123
|
common_seed = 123
|
||||||
dataset_a = RandomDataset(random_seed=common_seed)
|
dataset_a = RandomDataset(random_seed=common_seed)
|
||||||
dataset_b = RandomDataset(random_seed=common_seed)
|
dataset_b = RandomDataset(random_seed=common_seed)
|
||||||
a = _collect_samples(dataset_a,
|
a = _collect_samples(
|
||||||
hf_tokenizer,
|
dataset_a,
|
||||||
num_requests=p.num_requests,
|
hf_tokenizer,
|
||||||
prefix_len=p.prefix_len,
|
num_requests=p.num_requests,
|
||||||
range_ratio=p.range_ratio,
|
prefix_len=p.prefix_len,
|
||||||
input_len=p.input_len,
|
range_ratio=p.range_ratio,
|
||||||
output_len=p.output_len)
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
|
|
||||||
# Perturb global RNG state to ensure isolation
|
# Perturb global RNG state to ensure isolation
|
||||||
random.seed(999)
|
random.seed(999)
|
||||||
@@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
|
|||||||
np.random.seed(888)
|
np.random.seed(888)
|
||||||
_ = [np.random.random() for _ in range(100)]
|
_ = [np.random.random() for _ in range(100)]
|
||||||
|
|
||||||
b = _collect_samples(dataset_b,
|
b = _collect_samples(
|
||||||
hf_tokenizer,
|
dataset_b,
|
||||||
num_requests=p.num_requests,
|
hf_tokenizer,
|
||||||
prefix_len=p.prefix_len,
|
num_requests=p.num_requests,
|
||||||
range_ratio=p.range_ratio,
|
prefix_len=p.prefix_len,
|
||||||
input_len=p.input_len,
|
range_ratio=p.range_ratio,
|
||||||
output_len=p.output_len)
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
assert a == b
|
assert a == b
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_dataset_different_seeds(
|
def test_random_dataset_different_seeds(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
|
||||||
random_dataset_params: Params) -> None:
|
) -> None:
|
||||||
"""Different seeds should change outputs with overwhelming likelihood."""
|
"""Different seeds should change outputs with overwhelming likelihood."""
|
||||||
p = random_dataset_params
|
p = random_dataset_params
|
||||||
seed_a = 0
|
seed_a = 0
|
||||||
dataset_a = RandomDataset(random_seed=seed_a)
|
dataset_a = RandomDataset(random_seed=seed_a)
|
||||||
a = _collect_samples(dataset_a,
|
a = _collect_samples(
|
||||||
hf_tokenizer,
|
dataset_a,
|
||||||
num_requests=p.num_requests,
|
hf_tokenizer,
|
||||||
prefix_len=p.prefix_len,
|
num_requests=p.num_requests,
|
||||||
range_ratio=p.range_ratio,
|
prefix_len=p.prefix_len,
|
||||||
input_len=p.input_len,
|
range_ratio=p.range_ratio,
|
||||||
output_len=p.output_len)
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
|
|
||||||
seed_b = 999
|
seed_b = 999
|
||||||
dataset_b = RandomDataset(random_seed=seed_b)
|
dataset_b = RandomDataset(random_seed=seed_b)
|
||||||
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
||||||
random.seed(seed_a)
|
random.seed(seed_a)
|
||||||
np.random.seed(seed_a)
|
np.random.seed(seed_a)
|
||||||
b = _collect_samples(dataset_b,
|
b = _collect_samples(
|
||||||
hf_tokenizer,
|
dataset_b,
|
||||||
num_requests=p.num_requests,
|
hf_tokenizer,
|
||||||
prefix_len=p.prefix_len,
|
num_requests=p.num_requests,
|
||||||
range_ratio=p.range_ratio,
|
prefix_len=p.prefix_len,
|
||||||
input_len=p.input_len,
|
range_ratio=p.range_ratio,
|
||||||
output_len=p.output_len)
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len,
|
||||||
|
)
|
||||||
assert a != b
|
assert a != b
|
||||||
|
|
||||||
|
|
||||||
@@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
|
|||||||
# RandomMultiModalDataset tests
|
# RandomMultiModalDataset tests
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def _mm_fingerprint_sample(
|
def _mm_fingerprint_sample(
|
||||||
req: SampleRequest,
|
req: SampleRequest,
|
||||||
) -> tuple[str, int, int, int, list[str]]:
|
) -> tuple[str, int, int, int, list[str]]:
|
||||||
@@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
|
|||||||
item_prefixes.append(f"video:{url[:22]}")
|
item_prefixes.append(f"video:{url[:22]}")
|
||||||
else:
|
else:
|
||||||
item_prefixes.append("unknown:")
|
item_prefixes.append("unknown:")
|
||||||
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
|
return (
|
||||||
item_prefixes)
|
req.prompt,
|
||||||
|
req.prompt_len,
|
||||||
|
req.expected_output_len,
|
||||||
|
len(items),
|
||||||
|
item_prefixes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _collect_mm_samples(
|
def _collect_mm_samples(
|
||||||
@@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
|
|||||||
fb = [_mm_fingerprint_sample(s) for s in b]
|
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||||
assert fa != fb
|
assert fa != fb
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_mm_respects_limits(
|
def test_random_mm_respects_limits(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
@@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
|||||||
for s in samples:
|
for s in samples:
|
||||||
assert s.multi_modal_data == []
|
assert s.multi_modal_data == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_random_mm_num_items_per_prompt(
|
def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||||
hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
|
||||||
ds = RandomMultiModalDataset(random_seed=0)
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
# Fixed number of images per prompt
|
# Fixed number of images per prompt
|
||||||
# set num_mm_items_range_ratio to 0.0
|
# set num_mm_items_range_ratio to 0.0
|
||||||
@@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
|
|||||||
def test_random_mm_bucket_config_not_mutated(
|
def test_random_mm_bucket_config_not_mutated(
|
||||||
hf_tokenizer: PreTrainedTokenizerBase,
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
ds = RandomMultiModalDataset(random_seed=0)
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
# This bucket config is not normalized to sum to 1
|
# This bucket config is not normalized to sum to 1
|
||||||
# and has more buckets than requested images
|
# and has more buckets than requested images
|
||||||
@@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
|
|||||||
# Ensure the original dict content is unchanged
|
# Ensure the original dict content is unchanged
|
||||||
assert original == snapshot
|
assert original == snapshot
|
||||||
|
|
||||||
|
|
||||||
# Vary number of mm items per prompt
|
# Vary number of mm items per prompt
|
||||||
# set num_mm_items_range_ratio to 0.5
|
# set num_mm_items_range_ratio to 0.5
|
||||||
samples_varying_items = _collect_mm_samples(
|
samples_varying_items = _collect_mm_samples(
|
||||||
|
|||||||
@@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
args = [
|
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
|
||||||
"--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"
|
|
||||||
]
|
|
||||||
|
|
||||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||||
yield remote_server
|
yield remote_server
|
||||||
@@ -46,6 +44,7 @@ def test_bench_serve(server):
|
|||||||
|
|
||||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_serve_chat(server):
|
def test_bench_serve_chat(server):
|
||||||
command = [
|
command = [
|
||||||
|
|||||||
@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
|||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_throughput():
|
def test_bench_throughput():
|
||||||
command = [
|
command = [
|
||||||
"vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len",
|
"vllm",
|
||||||
"32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy"
|
"bench",
|
||||||
|
"throughput",
|
||||||
|
"--model",
|
||||||
|
MODEL_NAME,
|
||||||
|
"--input-len",
|
||||||
|
"32",
|
||||||
|
"--output-len",
|
||||||
|
"1",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--load-format",
|
||||||
|
"dummy",
|
||||||
]
|
]
|
||||||
result = subprocess.run(command, capture_output=True, text=True)
|
result = subprocess.run(command, capture_output=True, text=True)
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ class LazyInitPass(InductorPass):
|
|||||||
and then immediately invoke it.
|
and then immediately invoke it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pass_cls: type[VllmInductorPass],
|
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
|
||||||
vllm_config: VllmConfig):
|
|
||||||
self.pass_cls = pass_cls
|
self.pass_cls = pass_cls
|
||||||
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
||||||
|
|
||||||
@@ -45,20 +44,18 @@ class TestBackend:
|
|||||||
Inductor config is default-initialized from VllmConfig.CompilationConfig.
|
Inductor config is default-initialized from VllmConfig.CompilationConfig.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
|
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
|
||||||
None]]):
|
|
||||||
self.custom_passes = list(passes)
|
self.custom_passes = list(passes)
|
||||||
compile_config = get_current_vllm_config().compilation_config
|
compile_config = get_current_vllm_config().compilation_config
|
||||||
self.inductor_config = compile_config.inductor_compile_config
|
self.inductor_config = compile_config.inductor_compile_config
|
||||||
self.inductor_config['force_disable_caches'] = True
|
self.inductor_config["force_disable_caches"] = True
|
||||||
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
|
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
|
||||||
|
|
||||||
def __call__(self, graph: fx.GraphModule, example_inputs):
|
def __call__(self, graph: fx.GraphModule, example_inputs):
|
||||||
self.graph_pre_compile = deepcopy(graph)
|
self.graph_pre_compile = deepcopy(graph)
|
||||||
from torch._inductor.compile_fx import compile_fx
|
from torch._inductor.compile_fx import compile_fx
|
||||||
return compile_fx(graph,
|
|
||||||
example_inputs,
|
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
|
||||||
config_patches=self.inductor_config)
|
|
||||||
|
|
||||||
@with_pattern_match_debug
|
@with_pattern_match_debug
|
||||||
def post_pass(self, graph: fx.Graph):
|
def post_pass(self, graph: fx.Graph):
|
||||||
@@ -82,8 +79,7 @@ class TestBackend:
|
|||||||
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
||||||
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
||||||
if fully_replaced:
|
if fully_replaced:
|
||||||
assert num_post == 0, \
|
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
||||||
f"Unexpected op {op.name()} in post-pass graph"
|
|
||||||
|
|
||||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||||
for op in ops:
|
for op in ops:
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ test_params_full_cudagraph = []
|
|||||||
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
|
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
|
||||||
for mla_backend in MLA_backends:
|
for mla_backend in MLA_backends:
|
||||||
test_params_full_cudagraph.append(
|
test_params_full_cudagraph.append(
|
||||||
pytest.param(
|
pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
|
||||||
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
|
)
|
||||||
|
|
||||||
# Qwen/Qwen2-1.5B-Instruct with other backends
|
# Qwen/Qwen2-1.5B-Instruct with other backends
|
||||||
other_backend_configs = [
|
other_backend_configs = [
|
||||||
@@ -47,7 +47,8 @@ other_backend_configs = [
|
|||||||
]
|
]
|
||||||
for backend_config in other_backend_configs:
|
for backend_config in other_backend_configs:
|
||||||
test_params_full_cudagraph.append(
|
test_params_full_cudagraph.append(
|
||||||
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
|
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class")
|
@pytest.fixture(scope="class")
|
||||||
@@ -55,8 +56,10 @@ def llm_pair(request):
|
|||||||
model, backend_config = request.param
|
model, backend_config = request.param
|
||||||
|
|
||||||
# Dynamically skip test if GPU capability is not met
|
# Dynamically skip test if GPU capability is not met
|
||||||
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
if (
|
||||||
!= current_platform.get_device_capability():
|
backend_config.specific_gpu_arch
|
||||||
|
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
|
||||||
|
):
|
||||||
if backend_config.specific_gpu_arch == (9, 0):
|
if backend_config.specific_gpu_arch == (9, 0):
|
||||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||||
elif backend_config.specific_gpu_arch == (10, 0):
|
elif backend_config.specific_gpu_arch == (10, 0):
|
||||||
@@ -76,8 +79,7 @@ def llm_pair(request):
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
max_num_seqs=128,
|
max_num_seqs=128,
|
||||||
compilation_config=\
|
compilation_config=CompilationConfig(**backend_config.comp_config),
|
||||||
CompilationConfig(**backend_config.comp_config),
|
|
||||||
generation_config="vllm",
|
generation_config="vllm",
|
||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
@@ -113,20 +115,22 @@ class TestFullCUDAGraph:
|
|||||||
meaning there would be multiple LLM instances hogging memory simultaneously.
|
meaning there would be multiple LLM instances hogging memory simultaneously.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
|
@pytest.mark.parametrize(
|
||||||
(1, 10),
|
("batch_size", "max_tokens"),
|
||||||
(7, 10),
|
[
|
||||||
(16, 10),
|
(1, 10),
|
||||||
(25, 10),
|
(7, 10),
|
||||||
(32, 10),
|
(16, 10),
|
||||||
(45, 10),
|
(25, 10),
|
||||||
(64, 10),
|
(32, 10),
|
||||||
(123, 10),
|
(45, 10),
|
||||||
(8, 5),
|
(64, 10),
|
||||||
(8, 30),
|
(123, 10),
|
||||||
])
|
(8, 5),
|
||||||
def test_full_cudagraph(self, batch_size, max_tokens,
|
(8, 30),
|
||||||
llm_pair: tuple[LLM, LLM]):
|
],
|
||||||
|
)
|
||||||
|
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
|
||||||
"""
|
"""
|
||||||
Test various batch sizes and max_tokens to ensure that the
|
Test various batch sizes and max_tokens to ensure that the
|
||||||
full cudagraph compilation works for padded cases too.
|
full cudagraph compilation works for padded cases too.
|
||||||
@@ -137,26 +141,34 @@ class TestFullCUDAGraph:
|
|||||||
prompts = ["the quick brown fox"] * batch_size
|
prompts = ["the quick brown fox"] * batch_size
|
||||||
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
||||||
# that can amplify tiny numeric differences across runtimes.
|
# that can amplify tiny numeric differences across runtimes.
|
||||||
sampling_params = SamplingParams(temperature=0.0,
|
sampling_params = SamplingParams(
|
||||||
max_tokens=max_tokens,
|
temperature=0.0, max_tokens=max_tokens, top_p=1.0
|
||||||
top_p=1.0)
|
)
|
||||||
|
|
||||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
# Check that all responses are the same
|
# Check that all responses are the same
|
||||||
for piecewise_res, full_res in zip(piecewise_responses,
|
for piecewise_res, full_res in zip(piecewise_responses, full_responses):
|
||||||
full_responses):
|
assert (
|
||||||
assert piecewise_res.outputs[0].text.lower() == \
|
piecewise_res.outputs[0].text.lower()
|
||||||
full_res.outputs[0].text.lower()
|
== full_res.outputs[0].text.lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||||
def test_full_cudagraph_with_invalid_backend():
|
def test_full_cudagraph_with_invalid_backend():
|
||||||
with temporary_environ({
|
with (
|
||||||
"VLLM_USE_V1": "1",
|
temporary_environ(
|
||||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
|
{
|
||||||
# Flex_Attention is not supported with full cuda graph
|
"VLLM_USE_V1": "1",
|
||||||
}), pytest.raises(RuntimeError):
|
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
|
||||||
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
# Flex_Attention is not supported with full cuda graph
|
||||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
|
}
|
||||||
|
),
|
||||||
|
pytest.raises(RuntimeError),
|
||||||
|
):
|
||||||
|
LLM(
|
||||||
|
model="Qwen/Qwen2-1.5B-Instruct",
|
||||||
|
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
|
||||||
|
)
|
||||||
|
|||||||
@@ -10,10 +10,14 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.compilation.backends import set_model_tag
|
from vllm.compilation.backends import set_model_tag
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||||
support_torch_compile)
|
from vllm.config import (
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
CompilationConfig,
|
||||||
VllmConfig, set_current_vllm_config)
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
@@ -27,12 +31,7 @@ RANDOM_SEED = 0
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class ParentModel(nn.Module):
|
class ParentModel(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -40,7 +39,6 @@ class ParentModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
|
|
||||||
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
||||||
@@ -51,17 +49,21 @@ class Attention(nn.Module):
|
|||||||
nn.init.xavier_normal_(
|
nn.init.xavier_normal_(
|
||||||
self.pre_attn.weight.data,
|
self.pre_attn.weight.data,
|
||||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
|
)
|
||||||
nn.init.xavier_normal_(
|
nn.init.xavier_normal_(
|
||||||
self.post_attn.weight.data,
|
self.post_attn.weight.data,
|
||||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x_f32 = x.float()
|
x_f32 = x.float()
|
||||||
return (x_f32 * torch.rsqrt(
|
return (
|
||||||
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
|
x_f32
|
||||||
self.rms_norm_weight).to(x.dtype)
|
* torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
|
||||||
|
* self.rms_norm_weight
|
||||||
|
).to(x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = self.pre_attn(x)
|
x = self.pre_attn(x)
|
||||||
@@ -76,14 +78,15 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class CompiledAttention(nn.Module):
|
class CompiledAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
*,
|
*,
|
||||||
mlp_size: int,
|
mlp_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = '',
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attn = Attention(mlp_size, hidden_size)
|
self.attn = Attention(mlp_size, hidden_size)
|
||||||
|
|
||||||
@@ -93,21 +96,21 @@ class CompiledAttention(nn.Module):
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class CompiledAttentionTwo(CompiledAttention):
|
class CompiledAttentionTwo(CompiledAttention):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.attn(x) + x
|
return self.attn(x) + x
|
||||||
|
|
||||||
|
|
||||||
@ignore_torch_compile
|
@ignore_torch_compile
|
||||||
class SimpleModelWithTwoGraphs(ParentModel):
|
class SimpleModelWithTwoGraphs(ParentModel):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
*,
|
*,
|
||||||
mlp_size: int,
|
mlp_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = '',
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
# Test will fail without set_model_tag here with error:
|
# Test will fail without set_model_tag here with error:
|
||||||
# "ValueError: too many values to unpack (expected 3)"
|
# "ValueError: too many values to unpack (expected 3)"
|
||||||
@@ -142,32 +145,45 @@ class SimpleModelWithTwoGraphs(ParentModel):
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
|
def run_model(
|
||||||
cudagraph_runtime_mode: CUDAGraphMode):
|
vllm_config: VllmConfig,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
cudagraph_runtime_mode: CUDAGraphMode,
|
||||||
|
):
|
||||||
with set_forward_context({}, vllm_config=vllm_config):
|
with set_forward_context({}, vllm_config=vllm_config):
|
||||||
# warmup for the model with cudagraph_mode NONE
|
# warmup for the model with cudagraph_mode NONE
|
||||||
model(inputs)
|
model(inputs)
|
||||||
|
|
||||||
# simulate cudagraphs capturing
|
# simulate cudagraphs capturing
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(inputs[:2])
|
model(inputs[:2])
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=1, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(inputs[:1])
|
model(inputs[:1])
|
||||||
|
|
||||||
# simulate cudagraphs replay
|
# simulate cudagraphs replay
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(inputs[:2])
|
output = model(inputs[:2])
|
||||||
|
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
@@ -178,82 +194,104 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
|||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
# piecewise compile
|
# piecewise compile
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
use_cudagraph=True,
|
level=CompilationLevel.PIECEWISE,
|
||||||
splitting_ops=["silly.attention"],
|
use_cudagraph=True,
|
||||||
cudagraph_capture_sizes=[1, 2],
|
splitting_ops=["silly.attention"],
|
||||||
))
|
cudagraph_capture_sizes=[1, 2],
|
||||||
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
model = (
|
||||||
hidden_size=HIDDEN_SIZE,
|
SimpleModelWithTwoGraphs(
|
||||||
vllm_config=vllm_config,
|
mlp_size=MLP_SIZE,
|
||||||
prefix='').eval().cuda()
|
hidden_size=HIDDEN_SIZE,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix="",
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
# Pre-allocate memory for CUDAGraph which expects
|
# Pre-allocate memory for CUDAGraph which expects
|
||||||
# static tensor addresses
|
# static tensor addresses
|
||||||
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
|
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=2, # two graphs for the model
|
num_graphs_seen=2, # two graphs for the model
|
||||||
num_piecewise_graphs_seen=6,
|
num_piecewise_graphs_seen=6,
|
||||||
# attn_one, attn_two each has 3 piecewise graphs
|
# attn_one, attn_two each has 3 piecewise graphs
|
||||||
# (pre attn, post attn, silly_attention) each
|
# (pre attn, post attn, silly_attention) each
|
||||||
num_piecewise_capturable_graphs_seen=4,
|
num_piecewise_capturable_graphs_seen=4,
|
||||||
# attn_one, attn_two has pre attn and post attn each, total=4
|
# attn_one, attn_two has pre attn and post attn each, total=4
|
||||||
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
|
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
|
||||||
num_cudagraph_captured=8,
|
num_cudagraph_captured=8,
|
||||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
|
||||||
|
|
||||||
# no compile or cudagraph
|
# no compile or cudagraph
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.NO_COMPILATION, ))
|
compilation_config=CompilationConfig(
|
||||||
|
level=CompilationLevel.NO_COMPILATION,
|
||||||
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
model = (
|
||||||
hidden_size=HIDDEN_SIZE,
|
SimpleModelWithTwoGraphs(
|
||||||
vllm_config=vllm_config,
|
mlp_size=MLP_SIZE,
|
||||||
prefix='').eval().cuda()
|
hidden_size=HIDDEN_SIZE,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix="",
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=0,
|
num_graphs_seen=0,
|
||||||
num_piecewise_graphs_seen=0,
|
num_piecewise_graphs_seen=0,
|
||||||
num_piecewise_capturable_graphs_seen=0,
|
num_piecewise_capturable_graphs_seen=0,
|
||||||
num_backend_compilations=0,
|
num_backend_compilations=0,
|
||||||
num_cudagraph_captured=0,
|
num_cudagraph_captured=0,
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
|
||||||
|
|
||||||
# piecewise compile without CUDA graph
|
# piecewise compile without CUDA graph
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
use_cudagraph=False,
|
level=CompilationLevel.PIECEWISE,
|
||||||
splitting_ops=["silly.attention"],
|
use_cudagraph=False,
|
||||||
))
|
splitting_ops=["silly.attention"],
|
||||||
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
model = (
|
||||||
hidden_size=HIDDEN_SIZE,
|
SimpleModelWithTwoGraphs(
|
||||||
vllm_config=vllm_config,
|
mlp_size=MLP_SIZE,
|
||||||
prefix='').eval().cuda()
|
hidden_size=HIDDEN_SIZE,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix="",
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=2,
|
num_graphs_seen=2,
|
||||||
num_piecewise_graphs_seen=6,
|
num_piecewise_graphs_seen=6,
|
||||||
num_piecewise_capturable_graphs_seen=4,
|
num_piecewise_capturable_graphs_seen=4,
|
||||||
num_backend_compilations=4,
|
num_backend_compilations=4,
|
||||||
num_cudagraph_captured=0, # no cudagraph captured
|
num_cudagraph_captured=0, # no cudagraph captured
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
|
||||||
|
|
||||||
# Generally don't expect outputs with and without inductor
|
# Generally don't expect outputs with and without inductor
|
||||||
# to be bitwise equivalent
|
# to be bitwise equivalent
|
||||||
|
|||||||
@@ -11,8 +11,13 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (
|
||||||
VllmConfig, set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
@@ -23,12 +28,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class SillyModel(nn.Module):
|
class SillyModel(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -60,53 +60,65 @@ def _run_simple_model(
|
|||||||
expected_num_backend_compilations,
|
expected_num_backend_compilations,
|
||||||
expected_num_cudagraph_captured,
|
expected_num_cudagraph_captured,
|
||||||
):
|
):
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
use_cudagraph=True,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_inductor=use_inductor,
|
use_cudagraph=True,
|
||||||
splitting_ops=splitting_ops,
|
use_inductor=use_inductor,
|
||||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
splitting_ops=splitting_ops,
|
||||||
cudagraph_copy_inputs=True,
|
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||||
cudagraph_capture_sizes=[1, 2],
|
cudagraph_copy_inputs=True,
|
||||||
))
|
cudagraph_capture_sizes=[1, 2],
|
||||||
|
)
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = SillyModel(vllm_config=vllm_config, prefix='')
|
model = SillyModel(vllm_config=vllm_config, prefix="")
|
||||||
|
|
||||||
inputs = torch.randn(100).cuda()
|
inputs = torch.randn(100).cuda()
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with (
|
||||||
|
compilation_counter.expect(
|
||||||
num_graphs_seen=1, # one graph for the model
|
num_graphs_seen=1, # one graph for the model
|
||||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||||
num_piecewise_capturable_graphs_seen=
|
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||||
expected_num_piecewise_capturable_graphs_seen,
|
|
||||||
num_backend_compilations=expected_num_backend_compilations,
|
num_backend_compilations=expected_num_backend_compilations,
|
||||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||||
), set_forward_context(None,
|
),
|
||||||
vllm_config=vllm_config): # background context
|
set_forward_context(None, vllm_config=vllm_config),
|
||||||
|
): # background context
|
||||||
# warm up with background context
|
# warm up with background context
|
||||||
model(inputs)
|
model(inputs)
|
||||||
|
|
||||||
# capturing/replaying should under context of cudagraph dispatching
|
# capturing/replaying should under context of cudagraph dispatching
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(2).cuda())
|
model(torch.randn(2).cuda())
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=1, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(1).cuda())
|
model(torch.randn(1).cuda())
|
||||||
|
|
||||||
input = torch.zeros(2).cuda()
|
input = torch.zeros(2).cuda()
|
||||||
reset_global_counter()
|
reset_global_counter()
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(input)
|
output = model(input)
|
||||||
assert get_global_counter() == 2
|
assert get_global_counter() == 2
|
||||||
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||||
@@ -122,10 +134,8 @@ def test_simple_piecewise_compile(use_inductor):
|
|||||||
use_inductor=use_inductor,
|
use_inductor=use_inductor,
|
||||||
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
||||||
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
||||||
expected_num_backend_compilations=
|
expected_num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||||
3, # num_piecewise_capturable_graphs_seen
|
expected_num_cudagraph_captured=6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
expected_num_cudagraph_captured=
|
|
||||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -134,8 +144,7 @@ def test_simple_piecewise_compile(use_inductor):
|
|||||||
def test_simple_inductor_graph_partition(splitting_ops):
|
def test_simple_inductor_graph_partition(splitting_ops):
|
||||||
assert VLLM_USE_V1
|
assert VLLM_USE_V1
|
||||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
pytest.skip("inductor graph partition is only available "
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
"in PyTorch 2.9+")
|
|
||||||
|
|
||||||
_run_simple_model(
|
_run_simple_model(
|
||||||
# inductor graph partition automatically resets splitting_ops
|
# inductor graph partition automatically resets splitting_ops
|
||||||
@@ -143,13 +152,9 @@ def test_simple_inductor_graph_partition(splitting_ops):
|
|||||||
splitting_ops=splitting_ops,
|
splitting_ops=splitting_ops,
|
||||||
use_inductor_graph_partition=True,
|
use_inductor_graph_partition=True,
|
||||||
use_inductor=True,
|
use_inductor=True,
|
||||||
expected_num_piecewise_graphs_seen=
|
expected_num_piecewise_graphs_seen=1, # since not splitting at fx graph level
|
||||||
1, # since not splitting at fx graph level
|
expected_num_piecewise_capturable_graphs_seen=1, # since not splitting at fx graph level
|
||||||
expected_num_piecewise_capturable_graphs_seen=
|
expected_num_backend_compilations=1, # since not splitting at fx graph level
|
||||||
1, # since not splitting at fx graph level
|
expected_num_cudagraph_captured=6, # inductor graph partition still captures 6
|
||||||
expected_num_backend_compilations=
|
|
||||||
1, # since not splitting at fx graph level
|
|
||||||
expected_num_cudagraph_captured=
|
|
||||||
6, # inductor graph partition still captures 6
|
|
||||||
# graph, same as fx graph partition.
|
# graph, same as fx graph partition.
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ This is a tractable model, the weights and computation are specially designed
|
|||||||
if the config `tractable_init` is set to True. Otherwise, the weights are
|
if the config `tractable_init` is set to True. Otherwise, the weights are
|
||||||
initialized randomly with a fixed seed.
|
initialized randomly with a fixed seed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@@ -17,8 +18,13 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (
|
||||||
VllmConfig, set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
@@ -43,15 +49,14 @@ class LlamaConfig:
|
|||||||
factors.append((k, v))
|
factors.append((k, v))
|
||||||
factors.sort()
|
factors.sort()
|
||||||
import hashlib
|
import hashlib
|
||||||
return hashlib.md5(str(factors).encode(),
|
|
||||||
usedforsecurity=False).hexdigest()
|
return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert self.mlp_size >= self.hidden_size
|
assert self.mlp_size >= self.hidden_size
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig) -> None:
|
def __init__(self, config: LlamaConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_projection = nn.Linear(
|
self.gate_up_projection = nn.Linear(
|
||||||
@@ -66,31 +71,31 @@ class LlamaMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.tractable_init:
|
if config.tractable_init:
|
||||||
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
|
nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
|
||||||
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
|
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
|
||||||
nn.init.eye_(self.down_projection.weight.data)
|
nn.init.eye_(self.down_projection.weight.data)
|
||||||
else:
|
else:
|
||||||
nn.init.xavier_normal_(self.gate_up_projection.weight.data,
|
nn.init.xavier_normal_(
|
||||||
generator=torch.Generator().manual_seed(
|
self.gate_up_projection.weight.data,
|
||||||
config.random_seed),
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
nn.init.xavier_normal_(self.down_projection.weight.data,
|
)
|
||||||
generator=torch.Generator().manual_seed(
|
nn.init.xavier_normal_(
|
||||||
config.random_seed),
|
self.down_projection.weight.data,
|
||||||
gain=0.001)
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
|
gain=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# for tractable_init and positive input, this is
|
# for tractable_init and positive input, this is
|
||||||
# essentially an elementwise-square
|
# essentially an elementwise-square
|
||||||
x = self.gate_up_projection(x)
|
x = self.gate_up_projection(x)
|
||||||
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
|
x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
|
||||||
x[:, x.size(1) // 2:])
|
|
||||||
x = self.down_projection(x)
|
x = self.down_projection(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LlamaAttention(nn.Module):
|
class LlamaAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig) -> None:
|
def __init__(self, config: LlamaConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.qkv_projection = nn.Linear(
|
self.qkv_projection = nn.Linear(
|
||||||
@@ -106,21 +111,25 @@ class LlamaAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.tractable_init:
|
if config.tractable_init:
|
||||||
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
|
nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
|
||||||
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
|
nn.init.eye_(
|
||||||
config.hidden_size])
|
self.qkv_projection.weight.data[
|
||||||
nn.init.eye_(self.qkv_projection.weight.data[2 *
|
config.hidden_size : 2 * config.hidden_size
|
||||||
config.hidden_size:])
|
]
|
||||||
|
)
|
||||||
|
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
|
||||||
nn.init.eye_(self.output_projection.weight.data)
|
nn.init.eye_(self.output_projection.weight.data)
|
||||||
else:
|
else:
|
||||||
nn.init.xavier_normal_(self.qkv_projection.weight.data,
|
nn.init.xavier_normal_(
|
||||||
generator=torch.Generator().manual_seed(
|
self.qkv_projection.weight.data,
|
||||||
config.random_seed),
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
gain=0.001)
|
gain=0.001,
|
||||||
nn.init.xavier_normal_(self.output_projection.weight.data,
|
)
|
||||||
generator=torch.Generator().manual_seed(
|
nn.init.xavier_normal_(
|
||||||
config.random_seed),
|
self.output_projection.weight.data,
|
||||||
gain=0.001)
|
generator=torch.Generator().manual_seed(config.random_seed),
|
||||||
|
gain=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -144,7 +153,6 @@ class LlamaAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
class LlamaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig) -> None:
|
def __init__(self, config: LlamaConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attention = LlamaAttention(config)
|
self.self_attention = LlamaAttention(config)
|
||||||
@@ -164,7 +172,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
- if residual is not None, the outputs are:
|
- if residual is not None, the outputs are:
|
||||||
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
|
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
|
||||||
- hidden_states = (residual + 1) ** 2
|
- hidden_states = (residual + 1) ** 2
|
||||||
""" # noqa
|
""" # noqa
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = hidden_states + 1
|
hidden_states = hidden_states + 1
|
||||||
@@ -173,8 +181,9 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = hidden_states + 1
|
hidden_states = hidden_states + 1
|
||||||
|
|
||||||
hidden_states = self.self_attention(positions=positions,
|
hidden_states = self.self_attention(
|
||||||
hidden_states=hidden_states)
|
positions=positions, hidden_states=hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -186,20 +195,22 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class LlamaModel(nn.Module):
|
class LlamaModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
*,
|
*,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
prefix: str = '',
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding_tokens = nn.Embedding(
|
self.embedding_tokens = nn.Embedding(
|
||||||
num_embeddings=config.vocab_size,
|
num_embeddings=config.vocab_size,
|
||||||
embedding_dim=config.hidden_size,
|
embedding_dim=config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[LlamaDecoderLayer(config) for _ in range(config.num_layers)])
|
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]
|
||||||
|
)
|
||||||
|
|
||||||
# this is the initial value of the hidden states
|
# this is the initial value of the hidden states
|
||||||
self.embedding_tokens.weight.data.fill_(config.init_value)
|
self.embedding_tokens.weight.data.fill_(config.init_value)
|
||||||
@@ -216,34 +227,39 @@ class LlamaModel(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def tractable_computation(input_ids: torch.Tensor,
|
def tractable_computation(
|
||||||
positions: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
config: LlamaConfig,
|
positions: torch.Tensor,
|
||||||
init_value: float = 1.0) -> torch.Tensor:
|
config: LlamaConfig,
|
||||||
hidden_states = torch.ones(input_ids.size(0),
|
init_value: float = 1.0,
|
||||||
config.hidden_size,
|
) -> torch.Tensor:
|
||||||
device=input_ids.device,
|
hidden_states = (
|
||||||
dtype=input_ids.dtype) * init_value
|
torch.ones(
|
||||||
|
input_ids.size(0),
|
||||||
|
config.hidden_size,
|
||||||
|
device=input_ids.device,
|
||||||
|
dtype=input_ids.dtype,
|
||||||
|
)
|
||||||
|
* init_value
|
||||||
|
)
|
||||||
|
|
||||||
# first layer
|
# first layer
|
||||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||||
hidden_states = (residual + 1)**2
|
hidden_states = (residual + 1) ** 2
|
||||||
|
|
||||||
# following layers
|
# following layers
|
||||||
for _ in range(config.num_layers - 1):
|
for _ in range(config.num_layers - 1):
|
||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||||
hidden_states = (residual + 1)**2
|
hidden_states = (residual + 1) ** 2
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(llama_config,
|
def run_model(
|
||||||
use_compile: bool,
|
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
|
||||||
use_inductor: bool,
|
) -> torch.Tensor:
|
||||||
split_attn: bool = False) -> torch.Tensor:
|
|
||||||
|
|
||||||
if use_compile:
|
if use_compile:
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
@@ -256,54 +272,66 @@ def run_model(llama_config,
|
|||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
else:
|
else:
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
level=CompilationLevel.NO_COMPILATION, )
|
level=CompilationLevel.NO_COMPILATION,
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=compilation_config,
|
vllm_config = VllmConfig(
|
||||||
additional_config=llama_config)
|
compilation_config=compilation_config, additional_config=llama_config
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = LlamaModel(config=llama_config,
|
model = (
|
||||||
vllm_config=vllm_config,
|
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||||
prefix="").eval().cuda()
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
|
|
||||||
with set_forward_context({},
|
with set_forward_context({}, vllm_config=vllm_config): # background context
|
||||||
vllm_config=vllm_config): # background context
|
|
||||||
B = 16 # max batch size
|
B = 16 # max batch size
|
||||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||||
positions = torch.arange(B).cuda()
|
positions = torch.arange(B).cuda()
|
||||||
|
|
||||||
# warmup for the model with cudagraph_mode NONE
|
# warmup for the model with cudagraph_mode NONE
|
||||||
model(input_ids, positions)
|
model(input_ids, positions)
|
||||||
|
|
||||||
# simulate cudagraphs capturing
|
# simulate cudagraphs capturing
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(input_ids[:2], positions[:2])
|
model(input_ids[:2], positions[:2])
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=1, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(input_ids[:1], positions[:1])
|
model(input_ids[:1], positions[:1])
|
||||||
|
|
||||||
input_ids[:2].zero_()
|
input_ids[:2].zero_()
|
||||||
# simulate cudagraphs replay
|
# simulate cudagraphs replay
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(input_ids[:2], positions[:2])
|
output = model(input_ids[:2], positions[:2])
|
||||||
|
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
|
|
||||||
if llama_config.tractable_init:
|
if llama_config.tractable_init:
|
||||||
expected_output = tractable_computation(input_ids[:2],
|
expected_output = tractable_computation(
|
||||||
positions[:2],
|
input_ids[:2], positions[:2], llama_config
|
||||||
llama_config).cpu()
|
).cpu()
|
||||||
|
|
||||||
assert torch.allclose(output, expected_output)
|
assert torch.allclose(output, expected_output)
|
||||||
else:
|
else:
|
||||||
@@ -314,27 +342,23 @@ def run_model(llama_config,
|
|||||||
def test_toy_llama(use_inductor: bool):
|
def test_toy_llama(use_inductor: bool):
|
||||||
# compare output with and without piecewise compilation
|
# compare output with and without piecewise compilation
|
||||||
|
|
||||||
llama_config = LlamaConfig(hidden_size=128,
|
llama_config = LlamaConfig(
|
||||||
mlp_size=256,
|
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
|
||||||
vocab_size=128,
|
)
|
||||||
num_layers=12)
|
|
||||||
|
|
||||||
tractable_config = LlamaConfig(hidden_size=128,
|
tractable_config = LlamaConfig(
|
||||||
mlp_size=256,
|
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
|
||||||
vocab_size=128,
|
)
|
||||||
num_layers=2,
|
|
||||||
tractable_init=True)
|
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=0,
|
num_graphs_seen=0,
|
||||||
num_piecewise_graphs_seen=0,
|
num_piecewise_graphs_seen=0,
|
||||||
num_piecewise_capturable_graphs_seen=0,
|
num_piecewise_capturable_graphs_seen=0,
|
||||||
num_backend_compilations=0,
|
num_backend_compilations=0,
|
||||||
num_cudagraph_captured=0,
|
num_cudagraph_captured=0,
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
|
||||||
run_model(llama_config, use_inductor=False, use_compile=False))
|
|
||||||
run_model(tractable_config, use_inductor=False, use_compile=False)
|
run_model(tractable_config, use_inductor=False, use_compile=False)
|
||||||
|
|
||||||
if use_inductor:
|
if use_inductor:
|
||||||
@@ -343,41 +367,41 @@ def test_toy_llama(use_inductor: bool):
|
|||||||
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1, # one graph for the model
|
num_graphs_seen=1, # one graph for the model
|
||||||
num_piecewise_graphs_seen=1,
|
num_piecewise_graphs_seen=1,
|
||||||
num_piecewise_capturable_graphs_seen=1,
|
num_piecewise_capturable_graphs_seen=1,
|
||||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||||
num_cudagraph_captured=
|
num_cudagraph_captured=2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
**kwargs,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(
|
||||||
run_model(llama_config,
|
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
|
||||||
use_inductor=use_inductor,
|
)
|
||||||
use_compile=True))
|
|
||||||
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1, # one graph for the model
|
num_graphs_seen=1, # one graph for the model
|
||||||
num_piecewise_graphs_seen=2 * llama_config.num_layers +
|
num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1
|
||||||
1, # 2 * num_layers + 1
|
num_piecewise_capturable_graphs_seen=1
|
||||||
num_piecewise_capturable_graphs_seen=1 +
|
+ llama_config.num_layers, # 1 + num_layers
|
||||||
llama_config.num_layers, # 1 + num_layers
|
num_backend_compilations=1
|
||||||
num_backend_compilations=1 +
|
+ llama_config.num_layers, # num_piecewise_capturable_graphs_seen
|
||||||
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
|
num_cudagraph_captured=2
|
||||||
num_cudagraph_captured=2 *
|
* (
|
||||||
(1 + llama_config.num_layers
|
1 + llama_config.num_layers
|
||||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(
|
||||||
run_model(llama_config,
|
run_model(
|
||||||
use_inductor=use_inductor,
|
llama_config,
|
||||||
use_compile=True,
|
use_inductor=use_inductor,
|
||||||
split_attn=True))
|
use_compile=True,
|
||||||
run_model(tractable_config,
|
split_attn=True,
|
||||||
use_inductor=use_inductor,
|
)
|
||||||
use_compile=True,
|
)
|
||||||
split_attn=True)
|
run_model(
|
||||||
|
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(1, len(outputs)):
|
for i in range(1, len(outputs)):
|
||||||
assert torch.allclose(outputs[0], outputs[i])
|
assert torch.allclose(outputs[0], outputs[i])
|
||||||
@@ -388,17 +412,15 @@ def benchmark():
|
|||||||
from triton.testing import do_bench
|
from triton.testing import do_bench
|
||||||
|
|
||||||
# similar to llama 3.1-8B
|
# similar to llama 3.1-8B
|
||||||
llama_config = LlamaConfig(hidden_size=4096,
|
llama_config = LlamaConfig(
|
||||||
mlp_size=14336,
|
hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
|
||||||
vocab_size=128 * 1024,
|
)
|
||||||
num_layers=32)
|
|
||||||
|
|
||||||
# a tiny model to measure the overhead
|
# a tiny model to measure the overhead
|
||||||
# of piecewise cudagraph
|
# of piecewise cudagraph
|
||||||
llama_config = LlamaConfig(hidden_size=40,
|
llama_config = LlamaConfig(
|
||||||
mlp_size=80,
|
hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
|
||||||
vocab_size=128,
|
)
|
||||||
num_layers=2)
|
|
||||||
|
|
||||||
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
|
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
|
||||||
|
|
||||||
@@ -424,12 +446,15 @@ def benchmark():
|
|||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
model = LlamaModel(config=llama_config,
|
model = (
|
||||||
vllm_config=vllm_config,
|
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||||
prefix="").eval().cuda().to(torch.bfloat16)
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
.to(torch.bfloat16)
|
||||||
|
)
|
||||||
|
|
||||||
B = 256 # max batch size
|
B = 256 # max batch size
|
||||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||||
positions = torch.arange(B).cuda().to(torch.bfloat16)
|
positions = torch.arange(B).cuda().to(torch.bfloat16)
|
||||||
|
|
||||||
graphs = {}
|
graphs = {}
|
||||||
@@ -451,21 +476,25 @@ def benchmark():
|
|||||||
# and use it later, because it will look up the name `b` in the
|
# and use it later, because it will look up the name `b` in the
|
||||||
# enclosing scope, and the value of `b` will always be 256.
|
# enclosing scope, and the value of `b` will always be 256.
|
||||||
# it is fine here, because we only use the lambda function once.
|
# it is fine here, because we only use the lambda function once.
|
||||||
runtime = do_bench(lambda: graphs[b][0] # noqa
|
runtime = do_bench(
|
||||||
(input_ids[:b], positions[:b])) # noqa
|
lambda: graphs[b][0]( # noqa
|
||||||
|
input_ids[:b], positions[:b]
|
||||||
|
)
|
||||||
|
) # noqa
|
||||||
piecewise_cudagraph_time[b] = runtime
|
piecewise_cudagraph_time[b] = runtime
|
||||||
else:
|
else:
|
||||||
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
|
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
|
||||||
eager_runtime = do_bench(
|
eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
|
||||||
lambda: model(input_ids[:b], positions[:b])) # noqa
|
|
||||||
full_cudagraph_time[b] = runtime
|
full_cudagraph_time[b] = runtime
|
||||||
eager_time[b] = eager_runtime
|
eager_time[b] = eager_runtime
|
||||||
|
|
||||||
# print in tabular format
|
# print in tabular format
|
||||||
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
|
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
|
||||||
for b in cudagraph_sizes:
|
for b in cudagraph_sizes:
|
||||||
print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
|
print(
|
||||||
f"\t{piecewise_cudagraph_time[b]:.3f}")
|
f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
|
||||||
|
f"\t{piecewise_cudagraph_time[b]:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -31,8 +31,9 @@ def reset_global_counter():
|
|||||||
_global_counter = 0
|
_global_counter = 0
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
def silly_attention(
|
||||||
out: torch.Tensor) -> None:
|
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Unified attention implementation that depends on
|
Unified attention implementation that depends on
|
||||||
all inputs and affects the output.
|
all inputs and affects the output.
|
||||||
@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|||||||
out.copy_(q + k + v)
|
out.copy_(q + k + v)
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
def silly_attention_fake(
|
||||||
out: torch.Tensor) -> None:
|
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
|
||||||
|
) -> None:
|
||||||
"""Fake implementation for testing"""
|
"""Fake implementation for testing"""
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -60,5 +62,5 @@ direct_register_custom_op(
|
|||||||
mutates_args=["out"],
|
mutates_args=["out"],
|
||||||
fake_impl=silly_attention_fake,
|
fake_impl=silly_attention_fake,
|
||||||
target_lib=silly_lib,
|
target_lib=silly_lib,
|
||||||
tags=(torch._C.Tag.cudagraph_unsafe, ),
|
tags=(torch._C.Tag.cudagraph_unsafe,),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,18 +8,30 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.collective_fusion import AsyncTPPass
|
from vllm.compilation.collective_fusion import AsyncTPPass
|
||||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (
|
||||||
PassConfig, VllmConfig)
|
CompilationConfig,
|
||||||
from vllm.distributed import (tensor_model_parallel_all_gather,
|
DeviceConfig,
|
||||||
tensor_model_parallel_reduce_scatter)
|
ModelConfig,
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
PassConfig,
|
||||||
initialize_model_parallel)
|
VllmConfig,
|
||||||
|
)
|
||||||
|
from vllm.distributed import (
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_reduce_scatter,
|
||||||
|
)
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
from ..models.registry import HF_EXAMPLE_MODELS
|
from ..models.registry import HF_EXAMPLE_MODELS
|
||||||
from ..utils import (compare_two_settings, create_new_process_for_each_test,
|
from ..utils import (
|
||||||
multi_gpu_test)
|
compare_two_settings,
|
||||||
|
create_new_process_for_each_test,
|
||||||
|
multi_gpu_test,
|
||||||
|
)
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
@@ -33,21 +45,20 @@ prompts = [
|
|||||||
|
|
||||||
|
|
||||||
class TestMMRSModel(torch.nn.Module):
|
class TestMMRSModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
(self.hidden_size * 2, hidden_size)),
|
torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the mm + reduce scatter in the FX graph
|
Forward pass implementing the mm + reduce scatter in the FX graph
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Reshape input
|
# Reshape input
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
@@ -66,14 +77,13 @@ class TestMMRSModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestAGMMModel(torch.nn.Module):
|
class TestAGMMModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.weight = torch.nn.Parameter(torch.empty(
|
self.weight = torch.nn.Parameter(
|
||||||
(hidden_size, hidden_size)),
|
torch.empty((hidden_size, hidden_size)), requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.weight, std=0.02)
|
torch.nn.init.normal_(self.weight, std=0.02)
|
||||||
|
|
||||||
@@ -96,32 +106,35 @@ class TestAGMMModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class _BaseScaledMMModel(torch.nn.Module):
|
class _BaseScaledMMModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
|
self.weight = (
|
||||||
.contiguous().transpose(0, 1)
|
torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
|
||||||
|
.contiguous()
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize scale_b for _scaled_mm.
|
# Initialize scale_b for _scaled_mm.
|
||||||
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
class TestScaledMMRSModel(_BaseScaledMMModel):
|
class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
||||||
|
|
||||||
"""
|
"""
|
||||||
fp8_input = input.to(FP8_DTYPE)
|
fp8_input = input.to(FP8_DTYPE)
|
||||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||||
scaled_mm = torch._scaled_mm(fp8_input,
|
scaled_mm = torch._scaled_mm(
|
||||||
self.weight,
|
fp8_input,
|
||||||
scale_a=scale_a,
|
self.weight,
|
||||||
scale_b=self.scale_b,
|
scale_a=scale_a,
|
||||||
out_dtype=self.dtype)
|
scale_b=self.scale_b,
|
||||||
|
out_dtype=self.dtype,
|
||||||
|
)
|
||||||
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
||||||
return reduce_scatter
|
return reduce_scatter
|
||||||
|
|
||||||
@@ -133,7 +146,6 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the all gather + scaled_mm in the FX graph
|
Forward pass implementing the all gather + scaled_mm in the FX graph
|
||||||
@@ -143,11 +155,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
|
|||||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||||
|
|
||||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||||
scaled_mm = torch._scaled_mm(all_gather,
|
scaled_mm = torch._scaled_mm(
|
||||||
self.weight,
|
all_gather,
|
||||||
scale_a=scale_a,
|
self.weight,
|
||||||
scale_b=self.scale_b,
|
scale_a=scale_a,
|
||||||
out_dtype=self.dtype)
|
scale_b=self.scale_b,
|
||||||
|
out_dtype=self.dtype,
|
||||||
|
)
|
||||||
return scaled_mm
|
return scaled_mm
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -158,20 +172,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
||||||
in the FX graph
|
in the FX graph
|
||||||
|
|
||||||
"""
|
"""
|
||||||
fp8_input = input.to(FP8_DTYPE)
|
fp8_input = input.to(FP8_DTYPE)
|
||||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||||
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
|
mm_out = torch.empty(
|
||||||
dtype=self.dtype,
|
(fp8_input.shape[0], self.weight.shape[1]),
|
||||||
device=input.device)
|
dtype=self.dtype,
|
||||||
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
|
device=input.device,
|
||||||
self.scale_b, None)
|
)
|
||||||
|
torch.ops._C.cutlass_scaled_mm(
|
||||||
|
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
|
||||||
|
)
|
||||||
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
||||||
return reduce_scatter
|
return reduce_scatter
|
||||||
|
|
||||||
@@ -183,10 +199,9 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor):
|
def forward(self, input: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the all gather + cutlass_scaled_mm
|
Forward pass implementing the all gather + cutlass_scaled_mm
|
||||||
in the FX graph
|
in the FX graph
|
||||||
"""
|
"""
|
||||||
# Reshape input
|
# Reshape input
|
||||||
@@ -195,11 +210,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||||
|
|
||||||
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
|
mm_out = torch.empty(
|
||||||
dtype=self.dtype,
|
(all_gather.shape[0], self.weight.shape[1]),
|
||||||
device=all_gather.device)
|
dtype=self.dtype,
|
||||||
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
|
device=all_gather.device,
|
||||||
scale_a, self.scale_b, None)
|
)
|
||||||
|
torch.ops._C.cutlass_scaled_mm(
|
||||||
|
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
|
||||||
|
)
|
||||||
return mm_out
|
return mm_out
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -210,23 +228,37 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
|||||||
|
|
||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize("test_model", [
|
@pytest.mark.parametrize(
|
||||||
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
|
"test_model",
|
||||||
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
|
[
|
||||||
])
|
TestMMRSModel,
|
||||||
|
TestAGMMModel,
|
||||||
|
TestScaledMMRSModel,
|
||||||
|
TestAGScaledMMModel,
|
||||||
|
TestCutlassScaledMMRSModel,
|
||||||
|
TestAGCutlassScaledMMModel,
|
||||||
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("seq_len", [16])
|
@pytest.mark.parametrize("seq_len", [16])
|
||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
def test_async_tp_pass_replace(
|
||||||
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
|
||||||
hidden_size: int, dtype: torch.dtype):
|
):
|
||||||
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
|
if (
|
||||||
TestCutlassScaledMMRSModel,
|
test_model
|
||||||
TestAGCutlassScaledMMModel) and dtype == torch.float16:
|
in (
|
||||||
|
TestScaledMMRSModel,
|
||||||
|
TestAGScaledMMModel,
|
||||||
|
TestCutlassScaledMMRSModel,
|
||||||
|
TestAGCutlassScaledMMModel,
|
||||||
|
)
|
||||||
|
and dtype == torch.float16
|
||||||
|
):
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"Only bf16 high precision output types are supported for " \
|
"Only bf16 high precision output types are supported for "
|
||||||
"per-token (row-wise) scaling"
|
"per-token (row-wise) scaling"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -235,19 +267,24 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
|||||||
def run_torch_spawn(fn, nprocs):
|
def run_torch_spawn(fn, nprocs):
|
||||||
# need to use torch.mp.spawn otherwise will have problems with
|
# need to use torch.mp.spawn otherwise will have problems with
|
||||||
# torch.distributed and cuda
|
# torch.distributed and cuda
|
||||||
torch.multiprocessing.spawn(fn,
|
torch.multiprocessing.spawn(
|
||||||
args=(num_processes, test_model,
|
fn,
|
||||||
batch_size, seq_len, hidden_size,
|
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||||
dtype),
|
nprocs=nprocs,
|
||||||
nprocs=nprocs)
|
)
|
||||||
|
|
||||||
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
|
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
|
||||||
|
|
||||||
|
|
||||||
def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
def async_tp_pass_on_test_model(
|
||||||
test_model_cls: torch.nn.Module,
|
local_rank: int,
|
||||||
batch_size: int, seq_len: int,
|
world_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype):
|
test_model_cls: torch.nn.Module,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
@@ -255,13 +292,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# initialize distributed
|
# initialize distributed
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -269,27 +308,28 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
|
|
||||||
# configure vllm config for SequenceParallelismPass
|
# configure vllm config for SequenceParallelismPass
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
vllm_config.compilation_config = CompilationConfig(
|
||||||
enable_async_tp=True, ), )
|
pass_config=PassConfig(
|
||||||
|
enable_async_tp=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
# in the vllm_config, it's not really used.
|
# in the vllm_config, it's not really used.
|
||||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||||
vllm_config.model_config = ModelConfig(model=model_name,
|
vllm_config.model_config = ModelConfig(
|
||||||
trust_remote_code=True,
|
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||||
dtype=dtype,
|
)
|
||||||
seed=42)
|
|
||||||
|
|
||||||
async_tp_pass = AsyncTPPass(vllm_config)
|
async_tp_pass = AsyncTPPass(vllm_config)
|
||||||
backend = TestBackend(async_tp_pass)
|
backend = TestBackend(async_tp_pass)
|
||||||
|
|
||||||
model = test_model_cls(hidden_size,
|
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||||
dtype) # Pass dtype to model constructor
|
|
||||||
|
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
hidden_states = torch.randn(
|
||||||
dtype=dtype,
|
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
|
|
||||||
compiled_model = torch.compile(model, backend=backend)
|
compiled_model = torch.compile(model, backend=backend)
|
||||||
compiled_model(hidden_states)
|
compiled_model(hidden_states)
|
||||||
@@ -306,10 +346,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
|
|
||||||
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@pytest.mark.parametrize("model_id", [
|
@pytest.mark.parametrize(
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"model_id",
|
||||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
|
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
|
||||||
])
|
)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||||
@@ -342,12 +382,10 @@ def test_async_tp_pass_correctness(
|
|||||||
common_args.append("--enforce-eager")
|
common_args.append("--enforce-eager")
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
'level': 3,
|
"level": 3,
|
||||||
'compile_sizes': [2, 4, 8],
|
"compile_sizes": [2, 4, 8],
|
||||||
'splitting_ops': [],
|
"splitting_ops": [],
|
||||||
'pass_config': {
|
"pass_config": {"enable_async_tp": async_tp_enabled},
|
||||||
'enable_async_tp': async_tp_enabled
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async_tp_env = tp_env = {
|
async_tp_env = tp_env = {
|
||||||
@@ -372,9 +410,6 @@ def test_async_tp_pass_correctness(
|
|||||||
"mp",
|
"mp",
|
||||||
]
|
]
|
||||||
|
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(
|
||||||
async_tp_args,
|
model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate"
|
||||||
tp_args,
|
)
|
||||||
async_tp_env,
|
|
||||||
tp_env,
|
|
||||||
method="generate")
|
|
||||||
|
|||||||
@@ -103,23 +103,28 @@ def test_compile_correctness(
|
|||||||
attn_backend = test_setting.attn_backend
|
attn_backend = test_setting.attn_backend
|
||||||
method = test_setting.method
|
method = test_setting.method
|
||||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
if cuda_device_count_stateless() < pp_size * tp_size:
|
||||||
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
pytest.skip(
|
||||||
f"{cuda_device_count_stateless()}")
|
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||||
|
f"{cuda_device_count_stateless()}"
|
||||||
|
)
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
final_args = [
|
final_args = [
|
||||||
"--enforce-eager", *model_args, "-pp",
|
"--enforce-eager",
|
||||||
str(pp_size), "-tp",
|
*model_args,
|
||||||
str(tp_size)
|
"-pp",
|
||||||
|
str(pp_size),
|
||||||
|
"-tp",
|
||||||
|
str(tp_size),
|
||||||
]
|
]
|
||||||
|
|
||||||
all_args: list[list[str]] = []
|
all_args: list[list[str]] = []
|
||||||
all_envs: list[dict[str, str] | None] = []
|
all_envs: list[dict[str, str] | None] = []
|
||||||
|
|
||||||
for level in [
|
for level in [
|
||||||
CompilationLevel.NO_COMPILATION,
|
CompilationLevel.NO_COMPILATION,
|
||||||
CompilationLevel.PIECEWISE,
|
CompilationLevel.PIECEWISE,
|
||||||
]:
|
]:
|
||||||
all_args.append(final_args + [f"-O{level}"])
|
all_args.append(final_args + [f"-O{level}"])
|
||||||
all_envs.append({})
|
all_envs.append({})
|
||||||
@@ -130,14 +135,15 @@ def test_compile_correctness(
|
|||||||
model,
|
model,
|
||||||
all_args,
|
all_args,
|
||||||
all_envs,
|
all_envs,
|
||||||
method=method if method != "generate" else "generate_close")
|
method=method if method != "generate" else "generate_close",
|
||||||
|
)
|
||||||
all_envs.clear()
|
all_envs.clear()
|
||||||
all_args.clear()
|
all_args.clear()
|
||||||
|
|
||||||
for level in [
|
for level in [
|
||||||
CompilationLevel.NO_COMPILATION,
|
CompilationLevel.NO_COMPILATION,
|
||||||
CompilationLevel.DYNAMO_AS_IS,
|
CompilationLevel.DYNAMO_AS_IS,
|
||||||
CompilationLevel.DYNAMO_ONCE,
|
CompilationLevel.DYNAMO_ONCE,
|
||||||
]:
|
]:
|
||||||
all_args.append(final_args + [f"-O{level}"])
|
all_args.append(final_args + [f"-O{level}"])
|
||||||
all_envs.append({})
|
all_envs.append({})
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ from vllm.utils import _is_torch_equal_or_newer
|
|||||||
|
|
||||||
|
|
||||||
def test_version():
|
def test_version():
|
||||||
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
|
||||||
assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
|
||||||
assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
|
||||||
assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev')
|
assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
|
||||||
assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev')
|
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
||||||
|
|
||||||
|
|
||||||
def test_use_cudagraphs_dynamic(monkeypatch):
|
def test_use_cudagraphs_dynamic(monkeypatch):
|
||||||
@@ -21,7 +21,7 @@ def test_use_cudagraphs_dynamic(monkeypatch):
|
|||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
assert vllm_config.compilation_config.use_cudagraph
|
assert vllm_config.compilation_config.use_cudagraph
|
||||||
|
|
||||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
assert not vllm_config.compilation_config.use_cudagraph
|
assert not vllm_config.compilation_config.use_cudagraph
|
||||||
|
|
||||||
@@ -44,19 +44,23 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
|
|||||||
assert vllm.envs.VLLM_USE_V1
|
assert vllm.envs.VLLM_USE_V1
|
||||||
|
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)
|
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
"use_cudagraph": False, # speed things up a bit
|
"use_cudagraph": False, # speed things up a bit
|
||||||
}
|
}
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(num_cache_entries_updated=0,
|
compilation_counter.expect(
|
||||||
num_compiled_artifacts_saved=0),
|
num_cache_entries_updated=0, num_compiled_artifacts_saved=0
|
||||||
# loading the model causes compilation (if enabled) to happen
|
),
|
||||||
vllm_runner('facebook/opt-125m',
|
# loading the model causes compilation (if enabled) to happen
|
||||||
compilation_config=compilation_config,
|
vllm_runner(
|
||||||
gpu_memory_utilization=0.4) as _):
|
"facebook/opt-125m",
|
||||||
|
compilation_config=compilation_config,
|
||||||
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -67,22 +71,25 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
|||||||
assert vllm.envs.VLLM_USE_V1
|
assert vllm.envs.VLLM_USE_V1
|
||||||
|
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
"cudagraph_capture_sizes": [100],
|
"cudagraph_capture_sizes": [100],
|
||||||
"use_cudagraph": enabled,
|
"use_cudagraph": enabled,
|
||||||
}
|
}
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(
|
compilation_counter.expect(
|
||||||
num_graphs_seen=1,
|
num_graphs_seen=1,
|
||||||
num_gpu_runner_capture_triggers=1 if enabled else 0,
|
num_gpu_runner_capture_triggers=1 if enabled else 0,
|
||||||
num_cudagraph_captured=13 if enabled else 0,
|
num_cudagraph_captured=13 if enabled else 0,
|
||||||
),
|
),
|
||||||
# loading the model causes compilation (if enabled) to happen
|
# loading the model causes compilation (if enabled) to happen
|
||||||
vllm_runner('facebook/opt-125m',
|
vllm_runner(
|
||||||
compilation_config=compilation_config,
|
"facebook/opt-125m",
|
||||||
gpu_memory_utilization=0.4) as _):
|
compilation_config=compilation_config,
|
||||||
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -90,14 +97,17 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
|||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_dynamo_as_is(vllm_runner, monkeypatch):
|
def test_dynamo_as_is(vllm_runner, monkeypatch):
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(dynamo_as_is_count=1),
|
compilation_counter.expect(dynamo_as_is_count=1),
|
||||||
# loading the model causes compilation (if enabled) to happen
|
# loading the model causes compilation (if enabled) to happen
|
||||||
vllm_runner('facebook/opt-125m',
|
vllm_runner(
|
||||||
compilation_config={"level": 1},
|
"facebook/opt-125m",
|
||||||
gpu_memory_utilization=0.4) as _):
|
compilation_config={"level": 1},
|
||||||
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -105,14 +115,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
|
|||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_no_compilation(vllm_runner, monkeypatch):
|
def test_no_compilation(vllm_runner, monkeypatch):
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(num_graphs_seen=0,
|
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
|
||||||
dynamo_as_is_count=0),
|
# loading the model causes compilation (if enabled) to happen
|
||||||
# loading the model causes compilation (if enabled) to happen
|
vllm_runner(
|
||||||
vllm_runner('facebook/opt-125m',
|
"facebook/opt-125m",
|
||||||
compilation_config={"level": 0},
|
compilation_config={"level": 0},
|
||||||
gpu_memory_utilization=0.4) as _):
|
gpu_memory_utilization=0.4,
|
||||||
|
) as _,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -120,77 +132,73 @@ def test_no_compilation(vllm_runner, monkeypatch):
|
|||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_enforce_eager(vllm_runner, monkeypatch):
|
def test_enforce_eager(vllm_runner, monkeypatch):
|
||||||
# Disable multiprocessing so that the counter is in the same process
|
# Disable multiprocessing so that the counter is in the same process
|
||||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
|
|
||||||
with (
|
with (
|
||||||
compilation_counter.expect(num_graphs_seen=0,
|
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
|
||||||
dynamo_as_is_count=0),
|
# loading the model causes compilation (if enabled) to happen
|
||||||
# loading the model causes compilation (if enabled) to happen
|
vllm_runner(
|
||||||
vllm_runner('facebook/opt-125m',
|
"facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
|
||||||
enforce_eager=True,
|
) as _,
|
||||||
gpu_memory_utilization=0.4) as _):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_splitting_ops_dynamic():
|
def test_splitting_ops_dynamic():
|
||||||
# Default config
|
# Default config
|
||||||
config = VllmConfig()
|
config = VllmConfig()
|
||||||
assert config.compilation_config.cudagraph_mode == \
|
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||||
CUDAGraphMode.FULL_AND_PIECEWISE
|
|
||||||
assert config.compilation_config.splitting_ops_contain_attention()
|
assert config.compilation_config.splitting_ops_contain_attention()
|
||||||
|
|
||||||
# When use_inductor_graph_partition=True
|
# When use_inductor_graph_partition=True
|
||||||
if _is_torch_equal_or_newer('2.9.0.dev'):
|
if _is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
# inductor graph partition is only available in PyTorch 2.9+.
|
# inductor graph partition is only available in PyTorch 2.9+.
|
||||||
# this is a fast config check so we are not using pytest.skip.
|
# this is a fast config check so we are not using pytest.skip.
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
use_inductor_graph_partition=True,
|
compilation_config=CompilationConfig(
|
||||||
splitting_ops=["silly_attention"]))
|
use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
|
||||||
|
)
|
||||||
|
)
|
||||||
# should ignore splitting_ops
|
# should ignore splitting_ops
|
||||||
assert config.compilation_config.splitting_ops == []
|
assert config.compilation_config.splitting_ops == []
|
||||||
|
|
||||||
# When attn_fusion pass enabled.
|
# When attn_fusion pass enabled.
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
pass_config={
|
compilation_config=CompilationConfig(
|
||||||
"enable_attn_fusion": True,
|
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||||
"enable_noop": True
|
custom_ops=["+quant_fp8"],
|
||||||
},
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
custom_ops=["+quant_fp8"],
|
)
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
)
|
||||||
))
|
|
||||||
assert config.compilation_config.splitting_ops == []
|
assert config.compilation_config.splitting_ops == []
|
||||||
# cudagraph mode also fall back to FULL
|
# cudagraph mode also fall back to FULL
|
||||||
assert config.compilation_config.cudagraph_mode == \
|
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||||
CUDAGraphMode.FULL
|
|
||||||
|
|
||||||
# splitting_ops can not contain attention ops when attn_fusion
|
# splitting_ops can not contain attention ops when attn_fusion
|
||||||
# pass enabled.
|
# pass enabled.
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
pass_config={
|
compilation_config=CompilationConfig(
|
||||||
"enable_attn_fusion": True,
|
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||||
"enable_noop": True
|
custom_ops=["+quant_fp8"],
|
||||||
},
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
custom_ops=["+quant_fp8"],
|
# work around for accessing all attntion ops
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
splitting_ops=CompilationConfig()._attention_ops,
|
||||||
# work around for accessing all attntion ops
|
)
|
||||||
splitting_ops=CompilationConfig()._attention_ops,
|
)
|
||||||
))
|
|
||||||
|
|
||||||
# When both use_inductor_graph_partition and attn_fusion pass enabled.
|
# When both use_inductor_graph_partition and attn_fusion pass enabled.
|
||||||
if _is_torch_equal_or_newer('2.9.0.dev'):
|
if _is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
config = VllmConfig(compilation_config=CompilationConfig(
|
config = VllmConfig(
|
||||||
use_inductor_graph_partition=True,
|
compilation_config=CompilationConfig(
|
||||||
pass_config={
|
use_inductor_graph_partition=True,
|
||||||
"enable_attn_fusion": True,
|
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||||
"enable_noop": True
|
custom_ops=["+quant_fp8"],
|
||||||
},
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
custom_ops=["+quant_fp8"],
|
)
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
)
|
||||||
))
|
|
||||||
assert config.compilation_config.splitting_ops == []
|
assert config.compilation_config.splitting_ops == []
|
||||||
# enable_attn_fusion is directly support under
|
# enable_attn_fusion is directly support under
|
||||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||||
# is unchanged.
|
# is unchanged.
|
||||||
assert config.compilation_config.cudagraph_mode == \
|
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||||
CUDAGraphMode.PIECEWISE
|
|
||||||
|
|||||||
@@ -4,10 +4,15 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||||
support_torch_compile)
|
from vllm.config import (
|
||||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
CacheConfig,
|
||||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
@@ -18,32 +23,42 @@ MLP_SIZE = 128
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
def run_model(
|
||||||
cudagraph_runtime_mode: CUDAGraphMode):
|
vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
|
||||||
|
):
|
||||||
with set_forward_context({}, vllm_config=vllm_config):
|
with set_forward_context({}, vllm_config=vllm_config):
|
||||||
# warmup for the model with cudagraph_mode NONE
|
# warmup for the model with cudagraph_mode NONE
|
||||||
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||||
|
|
||||||
# simulate cudagraphs capturing
|
# simulate cudagraphs capturing
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(2, MLP_SIZE).cuda())
|
model(torch.randn(2, MLP_SIZE).cuda())
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=1, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
model(torch.randn(1, MLP_SIZE).cuda())
|
model(torch.randn(1, MLP_SIZE).cuda())
|
||||||
|
|
||||||
# simulate cudagraphs replay
|
# simulate cudagraphs replay
|
||||||
with set_forward_context({},
|
with set_forward_context(
|
||||||
vllm_config=vllm_config,
|
{},
|
||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
vllm_config=vllm_config,
|
||||||
batch_descriptor=BatchDescriptor(
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2,
|
||||||
|
),
|
||||||
|
):
|
||||||
output = model(torch.randn(2, MLP_SIZE).cuda())
|
output = model(torch.randn(2, MLP_SIZE).cuda())
|
||||||
|
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
@@ -52,22 +67,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module,
|
|||||||
|
|
||||||
def test_ignore_torch_compile_decorator():
|
def test_ignore_torch_compile_decorator():
|
||||||
# piecewise
|
# piecewise
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
use_cudagraph=True,
|
level=CompilationLevel.PIECEWISE,
|
||||||
splitting_ops=["silly.attention"],
|
use_cudagraph=True,
|
||||||
cudagraph_capture_sizes=[1, 2],
|
splitting_ops=["silly.attention"],
|
||||||
))
|
cudagraph_capture_sizes=[1, 2],
|
||||||
|
)
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class A(nn.Module):
|
class A(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
|
||||||
*,
|
) -> None:
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -79,66 +93,60 @@ def test_ignore_torch_compile_decorator():
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
@ignore_torch_compile
|
@ignore_torch_compile
|
||||||
class B(A):
|
class B(A): ...
|
||||||
...
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class C(B):
|
class C(B): ...
|
||||||
...
|
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# A has support_torch_compile
|
# A has support_torch_compile
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1,
|
num_graphs_seen=1,
|
||||||
num_piecewise_graphs_seen=3,
|
num_piecewise_graphs_seen=3,
|
||||||
num_piecewise_capturable_graphs_seen=2,
|
num_piecewise_capturable_graphs_seen=2,
|
||||||
num_backend_compilations=2,
|
num_backend_compilations=2,
|
||||||
num_cudagraph_captured=4,
|
num_cudagraph_captured=4,
|
||||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# B's ignore_torch_compile should override A's support_torch_compile
|
# B's ignore_torch_compile should override A's support_torch_compile
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=0,
|
num_graphs_seen=0,
|
||||||
num_piecewise_graphs_seen=0,
|
num_piecewise_graphs_seen=0,
|
||||||
num_piecewise_capturable_graphs_seen=0,
|
num_piecewise_capturable_graphs_seen=0,
|
||||||
num_backend_compilations=0,
|
num_backend_compilations=0,
|
||||||
num_cudagraph_captured=0,
|
num_cudagraph_captured=0,
|
||||||
):
|
):
|
||||||
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# C's support_torch_compile should override B's ignore_torch_compile
|
# C's support_torch_compile should override B's ignore_torch_compile
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1,
|
num_graphs_seen=1,
|
||||||
num_piecewise_graphs_seen=3,
|
num_piecewise_graphs_seen=3,
|
||||||
num_piecewise_capturable_graphs_seen=2,
|
num_piecewise_capturable_graphs_seen=2,
|
||||||
num_backend_compilations=2,
|
num_backend_compilations=2,
|
||||||
num_cudagraph_captured=4,
|
num_cudagraph_captured=4,
|
||||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||||
|
|
||||||
|
|
||||||
# Only enable torch.compile if
|
# Only enable torch.compile if
|
||||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
@support_torch_compile(
|
||||||
kv_sharing_fast_prefill)
|
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
|
||||||
|
)
|
||||||
class B(nn.Module):
|
class B(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -152,15 +160,11 @@ class B(nn.Module):
|
|||||||
|
|
||||||
# Only enable torch.compile if
|
# Only enable torch.compile if
|
||||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
@support_torch_compile(
|
||||||
cache_config.kv_sharing_fast_prefill)
|
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
|
||||||
|
)
|
||||||
class A(nn.Module):
|
class A(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = '',
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
@@ -175,54 +179,60 @@ class A(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def test_conditional_compile_enable_if():
|
def test_conditional_compile_enable_if():
|
||||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
vllm_config = VllmConfig(
|
||||||
kv_sharing_fast_prefill=True, ),
|
cache_config=CacheConfig(
|
||||||
compilation_config=CompilationConfig(
|
kv_sharing_fast_prefill=True,
|
||||||
level=CompilationLevel.PIECEWISE,
|
),
|
||||||
use_cudagraph=True,
|
compilation_config=CompilationConfig(
|
||||||
splitting_ops=["silly.attention"],
|
level=CompilationLevel.PIECEWISE,
|
||||||
cudagraph_capture_sizes=[1, 2],
|
use_cudagraph=True,
|
||||||
))
|
splitting_ops=["silly.attention"],
|
||||||
|
cudagraph_capture_sizes=[1, 2],
|
||||||
|
),
|
||||||
|
)
|
||||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
# A has support_torch_compile but enable_if fn returns False
|
# A has support_torch_compile but enable_if fn returns False
|
||||||
# enalbe_if will be True for B, so we expect mod1 and mod2
|
# enalbe_if will be True for B, so we expect mod1 and mod2
|
||||||
# to be compiled
|
# to be compiled
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=2,
|
num_graphs_seen=2,
|
||||||
num_piecewise_graphs_seen=6,
|
num_piecewise_graphs_seen=6,
|
||||||
# 3 piecewise graphs per instance of B()
|
# 3 piecewise graphs per instance of B()
|
||||||
num_piecewise_capturable_graphs_seen=4,
|
num_piecewise_capturable_graphs_seen=4,
|
||||||
num_backend_compilations=4,
|
num_backend_compilations=4,
|
||||||
num_cudagraph_captured=8,
|
num_cudagraph_captured=8,
|
||||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||||
|
|
||||||
# Set kv_sharing_fast_prefill=False
|
# Set kv_sharing_fast_prefill=False
|
||||||
# which will cause A to be compiled and B to not be compiled
|
# which will cause A to be compiled and B to not be compiled
|
||||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
vllm_config = VllmConfig(
|
||||||
kv_sharing_fast_prefill=False, ),
|
cache_config=CacheConfig(
|
||||||
compilation_config=CompilationConfig(
|
kv_sharing_fast_prefill=False,
|
||||||
level=CompilationLevel.PIECEWISE,
|
),
|
||||||
use_cudagraph=True,
|
compilation_config=CompilationConfig(
|
||||||
splitting_ops=["silly.attention"],
|
level=CompilationLevel.PIECEWISE,
|
||||||
cudagraph_capture_sizes=[1, 2],
|
use_cudagraph=True,
|
||||||
))
|
splitting_ops=["silly.attention"],
|
||||||
|
cudagraph_capture_sizes=[1, 2],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1,
|
num_graphs_seen=1,
|
||||||
num_piecewise_graphs_seen=7,
|
num_piecewise_graphs_seen=7,
|
||||||
# 3 attn ops and 4 non-attn ops
|
# 3 attn ops and 4 non-attn ops
|
||||||
num_piecewise_capturable_graphs_seen=4,
|
num_piecewise_capturable_graphs_seen=4,
|
||||||
num_backend_compilations=4,
|
num_backend_compilations=4,
|
||||||
num_cudagraph_captured=8,
|
num_cudagraph_captured=8,
|
||||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ from tests.quantization.utils import is_quant_method_supported
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
|
||||||
PassConfig)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
@@ -25,43 +24,54 @@ from ..utils import create_new_process_for_each_test
|
|||||||
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
||||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||||
("facebook/opt-125m", {}),
|
("facebook/opt-125m", {}),
|
||||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
(
|
||||||
"dtype": torch.float16,
|
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||||
}),
|
{
|
||||||
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
|
"dtype": torch.float16,
|
||||||
"dtype": torch.float16,
|
},
|
||||||
}),
|
),
|
||||||
|
(
|
||||||
|
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
|
||||||
|
{
|
||||||
|
"dtype": torch.float16,
|
||||||
|
},
|
||||||
|
),
|
||||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||||
]
|
]
|
||||||
|
|
||||||
if all:
|
if all:
|
||||||
|
|
||||||
# TODO: figure out why this fails.
|
# TODO: figure out why this fails.
|
||||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gguf"
|
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
|
||||||
}))
|
)
|
||||||
|
|
||||||
if is_quant_method_supported("gptq"):
|
if is_quant_method_supported("gptq"):
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gptq"
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
|
||||||
}))
|
)
|
||||||
|
|
||||||
if is_quant_method_supported("gptq_marlin"):
|
if is_quant_method_supported("gptq_marlin"):
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gptq_marlin"
|
(
|
||||||
}))
|
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
|
||||||
|
{"quantization": "gptq_marlin"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if is_quant_method_supported("gptq_marlin_24"):
|
if is_quant_method_supported("gptq_marlin_24"):
|
||||||
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
|
TEST_MODELS.append(
|
||||||
"quantization": "gptq_marlin_24"
|
(
|
||||||
}))
|
"alexm-nm/tinyllama-24-marlin24-4bit-g128",
|
||||||
|
{"quantization": "gptq_marlin_24"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
TEST_MODELS.append(
|
||||||
"quantization": "AWQ"
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
|
||||||
}))
|
)
|
||||||
|
|
||||||
if keywords is None:
|
if keywords is None:
|
||||||
return TEST_MODELS
|
return TEST_MODELS
|
||||||
@@ -95,22 +105,34 @@ def test_full_graph(
|
|||||||
"compilation_config, model_info",
|
"compilation_config, model_info",
|
||||||
[
|
[
|
||||||
# additional compile sizes, only some of the models
|
# additional compile sizes, only some of the models
|
||||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
(
|
||||||
compile_sizes=[1, 2]), model)
|
CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
|
||||||
|
model,
|
||||||
|
)
|
||||||
for model in models_list(all=False)
|
for model in models_list(all=False)
|
||||||
] + [
|
]
|
||||||
|
+ [
|
||||||
# RMSNorm + quant fusion, only 8-bit quant models
|
# RMSNorm + quant fusion, only 8-bit quant models
|
||||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
(
|
||||||
custom_ops=["+rms_norm"],
|
CompilationConfig(
|
||||||
pass_config=PassConfig(enable_fusion=True,
|
level=CompilationLevel.PIECEWISE,
|
||||||
enable_noop=True)), model)
|
custom_ops=["+rms_norm"],
|
||||||
|
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||||
|
),
|
||||||
|
model,
|
||||||
|
)
|
||||||
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||||
] + [
|
]
|
||||||
|
+ [
|
||||||
# Test depyf integration works
|
# Test depyf integration works
|
||||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
(
|
||||||
debug_dump_path=tempfile.gettempdir()),
|
CompilationConfig(
|
||||||
("facebook/opt-125m", {})),
|
level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
|
||||||
] + [
|
),
|
||||||
|
("facebook/opt-125m", {}),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ [
|
||||||
# graph inductor partition
|
# graph inductor partition
|
||||||
(
|
(
|
||||||
CompilationConfig(
|
CompilationConfig(
|
||||||
@@ -119,20 +141,24 @@ def test_full_graph(
|
|||||||
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
|
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
|
||||||
use_inductor_graph_partition=True,
|
use_inductor_graph_partition=True,
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
compile_sizes=[1, 2]),
|
compile_sizes=[1, 2],
|
||||||
model) for model in models_list(all=False)
|
),
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
for model in models_list(all=False)
|
||||||
if is_torch_equal_or_newer("2.9.0.dev")
|
if is_torch_equal_or_newer("2.9.0.dev")
|
||||||
])
|
],
|
||||||
|
)
|
||||||
# only test some of the models
|
# only test some of the models
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_custom_compile_config(
|
def test_custom_compile_config(
|
||||||
compilation_config: CompilationConfig,
|
compilation_config: CompilationConfig,
|
||||||
model_info: tuple[str, dict[str, Any]],
|
model_info: tuple[str, dict[str, Any]],
|
||||||
):
|
):
|
||||||
if (compilation_config.use_inductor_graph_partition
|
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||||
and not is_torch_equal_or_newer("2.9.0.dev")):
|
"2.9.0.dev"
|
||||||
pytest.skip("inductor graph partition is only available "
|
):
|
||||||
"in PyTorch 2.9+")
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
|
|
||||||
model, model_kwargs = model_info
|
model, model_kwargs = model_info
|
||||||
print(f"MODEL={model}")
|
print(f"MODEL={model}")
|
||||||
@@ -156,8 +182,7 @@ def test_fp8_kv_scale_compile(optimization_level: int):
|
|||||||
|
|
||||||
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
||||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
pytest.skip("inductor graph partition is only available "
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
"in PyTorch 2.9+")
|
|
||||||
|
|
||||||
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
@@ -171,14 +196,16 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
|||||||
"kv_cache_dtype": "fp8",
|
"kv_cache_dtype": "fp8",
|
||||||
"max_model_len": 1024,
|
"max_model_len": 1024,
|
||||||
}
|
}
|
||||||
with caplog_vllm.at_level(
|
with (
|
||||||
logging.DEBUG), global_force_attn_backend_context_manager(
|
caplog_vllm.at_level(logging.DEBUG),
|
||||||
_Backend.FLASHINFER):
|
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
|
||||||
|
):
|
||||||
run_model(compilation_config, model, model_kwargs)
|
run_model(compilation_config, model, model_kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert ("Fused quantization onto 48 attention nodes"
|
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
|
||||||
in caplog_vllm.text), caplog_vllm.text
|
caplog_vllm.text
|
||||||
|
)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
# Note: this message is only triggered when the compilation goes
|
# Note: this message is only triggered when the compilation goes
|
||||||
# through the custom pass. Due to multiple layers of cache on
|
# through the custom pass. Due to multiple layers of cache on
|
||||||
@@ -189,8 +216,11 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
|||||||
assert "Fused quantization" not in caplog_vllm.text
|
assert "Fused quantization" not in caplog_vllm.text
|
||||||
|
|
||||||
|
|
||||||
def run_model(compile_config: Union[int, CompilationConfig], model: str,
|
def run_model(
|
||||||
model_kwargs: dict[str, Any]):
|
compile_config: Union[int, CompilationConfig],
|
||||||
|
model: str,
|
||||||
|
model_kwargs: dict[str, Any],
|
||||||
|
):
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
|
|||||||
@@ -14,10 +14,8 @@ from vllm.compilation.post_cleanup import PostCleanupPass
|
|||||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
GroupShape)
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
Fp8LinearOp)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -28,7 +26,6 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
|||||||
|
|
||||||
|
|
||||||
class TestSiluMul(torch.nn.Module):
|
class TestSiluMul(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int = 128):
|
def __init__(self, hidden_size: int = 128):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
@@ -36,8 +33,7 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
|
|
||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
self.w = torch.rand(hidden_size,
|
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
hidden_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
self.fp8_linear = Fp8LinearOp(
|
self.fp8_linear = Fp8LinearOp(
|
||||||
act_quant_static=True,
|
act_quant_static=True,
|
||||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||||
@@ -46,17 +42,14 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
x2 = self.fp8_linear.apply(y,
|
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||||
self.w,
|
|
||||||
self.wscale,
|
|
||||||
input_scale=self.wscale)
|
|
||||||
return x2
|
return x2
|
||||||
else:
|
else:
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def example_inputs(self, num_tokens=32, hidden_size=128):
|
def example_inputs(self, num_tokens=32, hidden_size=128):
|
||||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
|
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
|
||||||
|
|
||||||
def ops_in_model(self, do_fusion):
|
def ops_in_model(self, do_fusion):
|
||||||
if TEST_FP8 and do_fusion:
|
if TEST_FP8 and do_fusion:
|
||||||
@@ -69,7 +62,6 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestFusedAddRMSNorm(torch.nn.Module):
|
class TestFusedAddRMSNorm(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -78,10 +70,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||||
|
|
||||||
self.gate_proj = torch.nn.Parameter(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
torch.empty((intermediate_size, hidden_size), dtype=dtype))
|
torch.empty((intermediate_size, hidden_size), dtype=dtype)
|
||||||
|
)
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||||
self.norm.weight = torch.nn.Parameter(
|
self.norm.weight = torch.nn.Parameter(
|
||||||
torch.ones(intermediate_size, dtype=dtype))
|
torch.ones(intermediate_size, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
|
|
||||||
@@ -89,8 +83,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||||
|
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.w = torch.rand(hidden_size,
|
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||||
intermediate_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
@@ -120,10 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
|
|
||||||
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
||||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
dtype=dtype)
|
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
|
||||||
dtype=dtype)
|
|
||||||
return (hidden_states, residual)
|
return (hidden_states, residual)
|
||||||
|
|
||||||
def ops_in_model(self, do_fusion):
|
def ops_in_model(self, do_fusion):
|
||||||
@@ -137,12 +128,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestRotaryEmbedding(torch.nn.Module):
|
class TestRotaryEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||||
def __init__(self,
|
|
||||||
head_dim=64,
|
|
||||||
rotary_dim=None,
|
|
||||||
max_position=2048,
|
|
||||||
base=10000):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.rotary_dim = rotary_dim or head_dim
|
self.rotary_dim = rotary_dim or head_dim
|
||||||
@@ -173,21 +159,15 @@ class TestRotaryEmbedding(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||||
|
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
|
||||||
def __init__(self,
|
|
||||||
head_dim=64,
|
|
||||||
num_heads=4,
|
|
||||||
max_position=2048,
|
|
||||||
base=10000):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.hidden_size = head_dim * num_heads
|
self.hidden_size = head_dim * num_heads
|
||||||
|
|
||||||
self.qkv_proj = torch.nn.Linear(self.hidden_size,
|
self.qkv_proj = torch.nn.Linear(
|
||||||
self.hidden_size * 3,
|
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
|
||||||
bias=False,
|
)
|
||||||
dtype=torch.float16)
|
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -233,21 +213,24 @@ MODELS = [
|
|||||||
|
|
||||||
@pytest.mark.parametrize("model_class", MODELS)
|
@pytest.mark.parametrize("model_class", MODELS)
|
||||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
|
||||||
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(
|
vllm_config.compilation_config = CompilationConfig(
|
||||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
|
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
|
||||||
|
)
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||||
|
|
||||||
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
passes = (
|
||||||
if do_fusion else [noop_pass, cleanup_pass])
|
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||||
|
if do_fusion
|
||||||
|
else [noop_pass, cleanup_pass]
|
||||||
|
)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
|
|
||||||
backend_func = TestBackend(*passes, func_pass)
|
backend_func = TestBackend(*passes, func_pass)
|
||||||
@@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
|||||||
# check if the functionalization pass is applied
|
# check if the functionalization pass is applied
|
||||||
for op in model.ops_in_model(do_fusion):
|
for op in model.ops_in_model(do_fusion):
|
||||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||||
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
|
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
|
||||||
is None) # noqa: E501
|
|
||||||
|
|
||||||
# make sure the ops were all de-functionalized
|
# make sure the ops were all de-functionalized
|
||||||
found = dict()
|
found = dict()
|
||||||
|
|||||||
@@ -5,17 +5,26 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.plugins
|
import vllm.plugins
|
||||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
from vllm.compilation.fusion import (
|
||||||
RMSNormQuantFusionPass)
|
FUSED_OPS,
|
||||||
|
QUANT_OPS,
|
||||||
|
FusedRMSQuantKey,
|
||||||
|
RMSNormQuantFusionPass,
|
||||||
|
)
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
|
||||||
VllmConfig)
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape, QuantKey, ScaleDesc)
|
GroupShape,
|
||||||
|
QuantKey,
|
||||||
|
ScaleDesc,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
|
Fp8LinearOp,
|
||||||
|
cutlass_fp8_supported,
|
||||||
|
maybe_create_device_identity,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import override_cutlass_fp8_supported
|
from ..utils import override_cutlass_fp8_supported
|
||||||
@@ -25,9 +34,15 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
|||||||
|
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self, hidden_size: int, eps: float, static: bool,
|
self,
|
||||||
cuda_force_torch: bool, *args, **kwargs):
|
hidden_size: int,
|
||||||
|
eps: float,
|
||||||
|
static: bool,
|
||||||
|
cuda_force_torch: bool,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.cuda_force_torch = cuda_force_torch
|
self.cuda_force_torch = cuda_force_torch
|
||||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||||
@@ -54,17 +69,15 @@ class TestModel(torch.nn.Module):
|
|||||||
resid = torch.sqrt(x)
|
resid = torch.sqrt(x)
|
||||||
y = self.norm[0](x)
|
y = self.norm[0](x)
|
||||||
|
|
||||||
x2 = self.fp8_linear.apply(y,
|
x2 = self.fp8_linear.apply(
|
||||||
self.w[0],
|
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||||
self.wscale[0],
|
)
|
||||||
input_scale=self.scale[0])
|
|
||||||
# make sure resid is used for replacement to work
|
# make sure resid is used for replacement to work
|
||||||
y2, resid = self.norm[1](x2, resid)
|
y2, resid = self.norm[1](x2, resid)
|
||||||
|
|
||||||
x3 = self.fp8_linear.apply(y2,
|
x3 = self.fp8_linear.apply(
|
||||||
self.w[1],
|
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||||
self.wscale[1],
|
)
|
||||||
input_scale=self.scale[1])
|
|
||||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||||
return y3
|
return y3
|
||||||
|
|
||||||
@@ -74,7 +87,7 @@ class TestModel(torch.nn.Module):
|
|||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
return [
|
return [
|
||||||
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
|
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
|
||||||
FUSED_OPS[FusedRMSQuantKey(self.key, True)]
|
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -85,22 +98,27 @@ class TestModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("static", [True, False])
|
@pytest.mark.parametrize("static", [True, False])
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize("cuda_force_torch",
|
@pytest.mark.parametrize(
|
||||||
[True, False] if cutlass_fp8_supported() else [True])
|
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
)
|
||||||
reason="Only test on CUDA and ROCm")
|
@pytest.mark.skipif(
|
||||||
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||||
cuda_force_torch):
|
)
|
||||||
|
def test_fusion_rmsnorm_quant(
|
||||||
|
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
|
||||||
|
):
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
level=CompilationLevel.PIECEWISE,
|
||||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||||
))
|
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||||
|
)
|
||||||
|
)
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
|||||||
@@ -10,14 +10,24 @@ from vllm.compilation.collective_fusion import AllReduceFusionPass
|
|||||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
from vllm.config import (
|
||||||
ModelConfig, PassConfig, VllmConfig)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
DeviceConfig,
|
||||||
|
ModelConfig,
|
||||||
|
PassConfig,
|
||||||
|
VllmConfig,
|
||||||
|
)
|
||||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
from vllm.distributed.parallel_state import (
|
||||||
initialize_model_parallel)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
GroupShape, QuantFP8)
|
GroupShape,
|
||||||
|
QuantFP8,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
@@ -26,7 +36,6 @@ from .backend import TestBackend
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -47,7 +56,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -68,25 +76,22 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.norm = RMSNorm(hidden_size, eps)
|
self.norm = RMSNorm(hidden_size, eps)
|
||||||
self.quant_fp8 = QuantFP8(static=True,
|
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
|
||||||
group_shape=GroupShape.PER_TENSOR)
|
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.output = torch.empty((token_num, hidden_size),
|
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
torch.ops._C.static_scaled_fp8_quant(self.output,
|
torch.ops._C.static_scaled_fp8_quant(
|
||||||
norm_output.contiguous(),
|
self.output, norm_output.contiguous(), self.scale
|
||||||
self.scale)
|
)
|
||||||
return self.output, residual_output
|
return self.output, residual_output
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
@@ -95,35 +100,33 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.all_reduce.default,
|
torch.ops.vllm.all_reduce.default,
|
||||||
torch.ops._C.static_scaled_fp8_quant.default
|
torch.ops._C.static_scaled_fp8_quant.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.norm = RMSNorm(hidden_size, eps)
|
self.norm = RMSNorm(hidden_size, eps)
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.output = torch.empty((token_num, hidden_size),
|
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
round_up = lambda x, y: (x + y - 1) // y * y
|
round_up = lambda x, y: (x + y - 1) // y * y
|
||||||
rounded_m = round_up(token_num, 128)
|
rounded_m = round_up(token_num, 128)
|
||||||
scale_n = hidden_size // 16
|
scale_n = hidden_size // 16
|
||||||
rounded_n = round_up(scale_n, 4)
|
rounded_n = round_up(scale_n, 4)
|
||||||
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
|
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
|
||||||
dtype=torch.int32)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
||||||
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
|
torch.ops._C.scaled_fp4_quant(
|
||||||
self.output_scale, self.scale)
|
self.output, norm_output, self.output_scale, self.scale
|
||||||
|
)
|
||||||
return self.output, residual_output, self.output_scale
|
return self.output, residual_output, self.output_scale
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
@@ -132,7 +135,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
|||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.all_reduce.default,
|
torch.ops.vllm.all_reduce.default,
|
||||||
torch.ops._C.scaled_fp4_quant.default
|
torch.ops._C.scaled_fp4_quant.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -145,41 +148,55 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
|||||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||||
# TODO: Enable with torch==2.8.0
|
# TODO: Enable with torch==2.8.0
|
||||||
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||||
])
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("seq_len", [8])
|
@pytest.mark.parametrize("seq_len", [8])
|
||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not find_spec("flashinfer")
|
not find_spec("flashinfer")
|
||||||
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
|
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
|
||||||
reason="flashinfer is not found or flashinfer "
|
reason="flashinfer is not found or flashinfer "
|
||||||
"is not compiled with trtllm_allreduce_fusion")
|
"is not compiled with trtllm_allreduce_fusion",
|
||||||
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
)
|
||||||
batch_size: int, seq_len: int,
|
def test_all_reduce_fusion_pass_replace(
|
||||||
hidden_size: int, dtype: torch.dtype):
|
test_model: torch.nn.Module,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
num_processes = 2
|
num_processes = 2
|
||||||
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
if (
|
||||||
and not current_platform.has_device_capability(100)):
|
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
||||||
pytest.skip("Skip as nvfp4 is only supported on "
|
and not current_platform.has_device_capability(100)
|
||||||
"devices with compute capability 10.0 (Blackwell)")
|
):
|
||||||
|
pytest.skip(
|
||||||
|
"Skip as nvfp4 is only supported on "
|
||||||
|
"devices with compute capability 10.0 (Blackwell)"
|
||||||
|
)
|
||||||
|
|
||||||
def run_torch_spawn(fn, nprocs):
|
def run_torch_spawn(fn, nprocs):
|
||||||
torch.multiprocessing.spawn(fn,
|
torch.multiprocessing.spawn(
|
||||||
args=(num_processes, test_model,
|
fn,
|
||||||
batch_size, seq_len, hidden_size,
|
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||||
dtype),
|
nprocs=nprocs,
|
||||||
nprocs=nprocs)
|
)
|
||||||
|
|
||||||
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
def all_reduce_fusion_pass_on_test_model(
|
||||||
test_model_cls: torch.nn.Module,
|
local_rank: int,
|
||||||
batch_size: int, seq_len: int,
|
world_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype):
|
test_model_cls: torch.nn.Module,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
@@ -187,39 +204,42 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
custom_ops=["+rms_norm", "+quant_fp8"]))
|
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
|
||||||
|
)
|
||||||
|
)
|
||||||
vllm_config.compilation_config.pass_config = PassConfig(
|
vllm_config.compilation_config.pass_config = PassConfig(
|
||||||
enable_fi_allreduce_fusion=True, enable_noop=True)
|
enable_fi_allreduce_fusion=True, enable_noop=True
|
||||||
|
)
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
# in the vllm_config, it's not really used.
|
# in the vllm_config, it's not really used.
|
||||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||||
vllm_config.model_config = ModelConfig(model=model_name,
|
vllm_config.model_config = ModelConfig(
|
||||||
trust_remote_code=True,
|
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||||
dtype=dtype,
|
)
|
||||||
seed=42)
|
|
||||||
|
|
||||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
|
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
|
||||||
cleanup_pass)
|
|
||||||
|
|
||||||
token_num = batch_size * seq_len
|
token_num = batch_size * seq_len
|
||||||
model = test_model_cls(hidden_size, token_num)
|
model = test_model_cls(hidden_size, token_num)
|
||||||
|
|||||||
@@ -19,14 +19,23 @@ from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
|||||||
from vllm.compilation.fx_utils import find_op_nodes
|
from vllm.compilation.fx_utils import find_op_nodes
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
from vllm.config import (
|
||||||
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
|
CacheConfig,
|
||||||
set_current_vllm_config)
|
CompilationConfig,
|
||||||
|
CompilationLevel,
|
||||||
|
ModelConfig,
|
||||||
|
PassConfig,
|
||||||
|
SchedulerConfig,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
|
QuantKey,
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
kFp8StaticTensorSym,
|
||||||
Fp8LinearOp)
|
kNvfp4Quant,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
@@ -40,14 +49,16 @@ backend_unfused: Optional[TestBackend] = None
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, quant_key",
|
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
|
||||||
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
|
)
|
||||||
@pytest.mark.parametrize("use_triton_fa", [True, False])
|
@pytest.mark.parametrize("use_triton_fa", [True, False])
|
||||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
@pytest.mark.skipif(
|
||||||
reason="V0 attn quant fusion only on ROCm")
|
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
|
||||||
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
)
|
||||||
quant_key: QuantKey, use_triton_fa: bool):
|
def test_attention_fusion_v0(
|
||||||
|
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
|
||||||
|
):
|
||||||
# Clean Dynamo cache to avoid reusing other test cases
|
# Clean Dynamo cache to avoid reusing other test cases
|
||||||
# (for some reason the reset at the end is not enough)
|
# (for some reason the reset at the end is not enough)
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
@@ -69,22 +80,24 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
)
|
)
|
||||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
vllm_config = VllmConfig(
|
||||||
model_config=ModelConfig(
|
compilation_config=compile_config,
|
||||||
model=model,
|
model_config=ModelConfig(
|
||||||
dtype=torch.bfloat16,
|
model=model,
|
||||||
))
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
)
|
||||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||||
|
|
||||||
llm = LLM(model,
|
llm = LLM(
|
||||||
enforce_eager=True,
|
model,
|
||||||
compilation_config=compile_config,
|
enforce_eager=True,
|
||||||
gpu_memory_utilization=0.5,
|
compilation_config=compile_config,
|
||||||
max_model_len=2048)
|
gpu_memory_utilization=0.5,
|
||||||
|
max_model_len=2048,
|
||||||
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.0,
|
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
|
||||||
max_tokens=10,
|
|
||||||
top_p=0.95)
|
|
||||||
|
|
||||||
unfused_output = llm.generate(prompts, sampling_params)
|
unfused_output = llm.generate(prompts, sampling_params)
|
||||||
backend_unfused = None # Reset backend to make sure llm gets released
|
backend_unfused = None # Reset backend to make sure llm gets released
|
||||||
@@ -97,21 +110,25 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
backend="tests.compile.test_fusion_attn.backend",
|
backend="tests.compile.test_fusion_attn.backend",
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
)
|
)
|
||||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
vllm_config = VllmConfig(
|
||||||
model_config=ModelConfig(
|
compilation_config=compile_config,
|
||||||
model=model,
|
model_config=ModelConfig(
|
||||||
dtype=torch.bfloat16,
|
model=model,
|
||||||
))
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# AttnFusionPass needs attention layers to be registered in config upon init
|
# AttnFusionPass needs attention layers to be registered in config upon init
|
||||||
# so we initialize it during compilation.
|
# so we initialize it during compilation.
|
||||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||||
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
||||||
llm2 = LLM(model,
|
llm2 = LLM(
|
||||||
enforce_eager=True,
|
model,
|
||||||
compilation_config=compile_config,
|
enforce_eager=True,
|
||||||
gpu_memory_utilization=0.5,
|
compilation_config=compile_config,
|
||||||
max_model_len=2048)
|
gpu_memory_utilization=0.5,
|
||||||
|
max_model_len=2048,
|
||||||
|
)
|
||||||
|
|
||||||
# check support
|
# check support
|
||||||
attn_fusion_supported = [
|
attn_fusion_supported = [
|
||||||
@@ -132,9 +149,9 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
for i in range(len(attn_nodes_pre)):
|
for i in range(len(attn_nodes_pre)):
|
||||||
assert attn_nodes_pre[i].kwargs["output_scale"] is None
|
assert attn_nodes_pre[i].kwargs["output_scale"] is None
|
||||||
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
|
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
|
||||||
assert fused == attn_fusion_supported[i], \
|
assert fused == attn_fusion_supported[i], (
|
||||||
f"Node {i} {'' if fused else 'not '} expected " \
|
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
|
||||||
f"to have fused output quant"
|
)
|
||||||
|
|
||||||
# check outputs
|
# check outputs
|
||||||
fused_output = llm2.generate(prompts, sampling_params)
|
fused_output = llm2.generate(prompts, sampling_params)
|
||||||
@@ -160,9 +177,16 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
|||||||
class AttentionQuantPatternModel(torch.nn.Module):
|
class AttentionQuantPatternModel(torch.nn.Module):
|
||||||
"""Base model for AttentionQuantPattern fusion."""
|
"""Base model for AttentionQuantPattern fusion."""
|
||||||
|
|
||||||
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
|
def __init__(
|
||||||
kv_cache_dtype: torch.dtype, device: torch.device,
|
self,
|
||||||
vllm_config: VllmConfig, **kwargs):
|
num_qo_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
kv_cache_dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_qo_heads = num_qo_heads
|
self.num_qo_heads = num_qo_heads
|
||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
@@ -197,33 +221,30 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
|
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
|
||||||
-> AttentionMetadata:
|
|
||||||
"""Initialize attention metadata."""
|
"""Initialize attention metadata."""
|
||||||
|
|
||||||
# Create common attn metadata
|
# Create common attn metadata
|
||||||
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
|
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
|
||||||
query_lens=[1] * batch_size)
|
|
||||||
common_attn_metadata = create_common_attn_metadata(
|
common_attn_metadata = create_common_attn_metadata(
|
||||||
batch_spec,
|
batch_spec, self.block_size, self.device, arange_block_indices=True
|
||||||
self.block_size,
|
)
|
||||||
self.device,
|
|
||||||
arange_block_indices=True)
|
|
||||||
|
|
||||||
max_blocks = (max(batch_spec.seq_lens) + self.block_size -
|
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
|
||||||
1) // self.block_size
|
|
||||||
num_blocks = batch_size * max_blocks
|
num_blocks = batch_size * max_blocks
|
||||||
|
|
||||||
# Create dummy KV cache for FlashInfer TRTLLM
|
# Create dummy KV cache for FlashInfer TRTLLM
|
||||||
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||||
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
|
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||||
kv_cache = torch.zeros(num_blocks,
|
kv_cache = torch.zeros(
|
||||||
2,
|
num_blocks,
|
||||||
self.num_kv_heads,
|
2,
|
||||||
self.block_size,
|
self.num_kv_heads,
|
||||||
self.head_size,
|
self.block_size,
|
||||||
dtype=self.kv_cache_dtype,
|
self.head_size,
|
||||||
device=self.device)
|
dtype=self.kv_cache_dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
# k/v as 1st dimention
|
# k/v as 1st dimention
|
||||||
if use_hnd:
|
if use_hnd:
|
||||||
@@ -239,7 +260,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
|||||||
|
|
||||||
# Build attn metadata
|
# Build attn metadata
|
||||||
self.attn_metadata = self.builder.build(
|
self.attn_metadata = self.builder.build(
|
||||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata)
|
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||||
|
)
|
||||||
|
|
||||||
return self.attn_metadata
|
return self.attn_metadata
|
||||||
|
|
||||||
@@ -254,27 +276,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
|||||||
|
|
||||||
self.fp8_linear = Fp8LinearOp(
|
self.fp8_linear = Fp8LinearOp(
|
||||||
act_quant_static=self.quant_key.scale.static,
|
act_quant_static=self.quant_key.scale.static,
|
||||||
act_quant_group_shape=self.quant_key.scale.group_shape)
|
act_quant_group_shape=self.quant_key.scale.group_shape,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_size = self.num_qo_heads * self.head_size
|
hidden_size = self.num_qo_heads * self.head_size
|
||||||
self.w = kwargs.get(
|
self.w = kwargs.get(
|
||||||
"w", {
|
"w",
|
||||||
"weight":
|
{
|
||||||
torch.randn(hidden_size, hidden_size).to(
|
"weight": torch.randn(hidden_size, hidden_size)
|
||||||
dtype=FP8_DTYPE, device=self.device).t(),
|
.to(dtype=FP8_DTYPE, device=self.device)
|
||||||
"wscale":
|
.t(),
|
||||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
"wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||||
"scale":
|
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
},
|
||||||
})
|
)
|
||||||
|
|
||||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
"""Forward pass that creates the pattern to be fused."""
|
"""Forward pass that creates the pattern to be fused."""
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
return self.fp8_linear.apply(input=attn_output,
|
return self.fp8_linear.apply(
|
||||||
weight=self.w["weight"],
|
input=attn_output,
|
||||||
weight_scale=self.w["wscale"],
|
weight=self.w["weight"],
|
||||||
input_scale=self.w["scale"])
|
weight_scale=self.w["wscale"],
|
||||||
|
input_scale=self.w["scale"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||||
@@ -287,42 +312,54 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
|||||||
|
|
||||||
hidden_size = self.num_qo_heads * self.head_size
|
hidden_size = self.num_qo_heads * self.head_size
|
||||||
self.w = kwargs.get(
|
self.w = kwargs.get(
|
||||||
"w", {
|
"w",
|
||||||
"weight":
|
{
|
||||||
torch.randint(256, (hidden_size, hidden_size // 2),
|
"weight": torch.randint(
|
||||||
dtype=FP4_DTYPE,
|
256,
|
||||||
device=self.device),
|
(hidden_size, hidden_size // 2),
|
||||||
"wscale_swizzled":
|
dtype=FP4_DTYPE,
|
||||||
torch.randn(hidden_size, hidden_size // 16).to(
|
device=self.device,
|
||||||
dtype=FP8_DTYPE, device=self.device),
|
),
|
||||||
"wscale":
|
"wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
|
||||||
torch.tensor([500], dtype=torch.float32, device=self.device),
|
dtype=FP8_DTYPE, device=self.device
|
||||||
"scale":
|
),
|
||||||
torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
"wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
|
||||||
})
|
"scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
"""Forward pass that creates the pattern to be fused."""
|
"""Forward pass that creates the pattern to be fused."""
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
quant_output, output_block_scale = scaled_fp4_quant(
|
quant_output, output_block_scale = scaled_fp4_quant(
|
||||||
attn_output, 1 / self.w["scale"])
|
attn_output, 1 / self.w["scale"]
|
||||||
return cutlass_scaled_fp4_mm(a=quant_output,
|
)
|
||||||
b=self.w["weight"],
|
return cutlass_scaled_fp4_mm(
|
||||||
block_scale_a=output_block_scale,
|
a=quant_output,
|
||||||
block_scale_b=self.w["wscale_swizzled"],
|
b=self.w["weight"],
|
||||||
alpha=self.w["scale"] * self.w["wscale"],
|
block_scale_a=output_block_scale,
|
||||||
out_dtype=attn_output.dtype)
|
block_scale_b=self.w["wscale_swizzled"],
|
||||||
|
alpha=self.w["scale"] * self.w["wscale"],
|
||||||
|
out_dtype=attn_output.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
MODELS = [
|
||||||
TestAttentionFp8StaticQuantPatternModel),
|
(
|
||||||
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||||
TestAttentionNvfp4QuantPatternModel)]
|
TestAttentionFp8StaticQuantPatternModel,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||||
|
TestAttentionNvfp4QuantPatternModel,
|
||||||
|
),
|
||||||
|
]
|
||||||
HEADS = [(64, 8), (40, 8)]
|
HEADS = [(64, 8), (40, 8)]
|
||||||
elif current_platform.is_rocm():
|
elif current_platform.is_rocm():
|
||||||
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
|
MODELS = [
|
||||||
TestAttentionFp8StaticQuantPatternModel)]
|
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||||
|
]
|
||||||
HEADS = [(32, 8), (40, 8)]
|
HEADS = [(32, 8), (40, 8)]
|
||||||
else:
|
else:
|
||||||
MODELS = []
|
MODELS = []
|
||||||
@@ -331,41 +368,53 @@ else:
|
|||||||
|
|
||||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
||||||
@pytest.mark.parametrize("head_size", [128])
|
@pytest.mark.parametrize("head_size", [128])
|
||||||
@pytest.mark.parametrize("batch_size",
|
@pytest.mark.parametrize(
|
||||||
[7, 256, 533] if current_platform.is_cuda() else [8])
|
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
||||||
@pytest.mark.parametrize("backend",
|
|
||||||
[_Backend.FLASHINFER] if current_platform.is_cuda()
|
|
||||||
else [_Backend.TRITON_ATTN])
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"split_attention",
|
"backend",
|
||||||
[False, True] if current_platform.is_rocm() else [False])
|
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"split_attention", [False, True] if current_platform.is_rocm() else [False]
|
||||||
|
)
|
||||||
# TODO(boyuan): test inductor graph partition on rocm
|
# TODO(boyuan): test inductor graph partition on rocm
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"use_inductor_graph_partition",
|
"use_inductor_graph_partition",
|
||||||
[False] if current_platform.is_rocm() else [False, True])
|
[False] if current_platform.is_rocm() else [False, True],
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
)
|
||||||
reason="Only test ROCm or CUDA")
|
@pytest.mark.skipif(
|
||||||
|
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||||
|
)
|
||||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||||
@pytest.mark.skipif(current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
and not current_platform.is_device_capability((10, 0)),
|
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
|
||||||
reason="On CUDA only test on SM100(Blackwell)")
|
reason="On CUDA only test on SM100(Blackwell)",
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
)
|
||||||
reason="Only test ROCm or CUDA")
|
@pytest.mark.skipif(
|
||||||
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||||
head_size: int, batch_size: int,
|
)
|
||||||
dtype: torch.dtype, model_name: str,
|
def test_attention_quant_pattern(
|
||||||
model_class: type[AttentionQuantPatternModel],
|
num_qo_heads: int,
|
||||||
backend: _Backend, split_attention: bool,
|
num_kv_heads: int,
|
||||||
use_inductor_graph_partition: bool,
|
head_size: int,
|
||||||
monkeypatch, dist_init, caplog_vllm):
|
batch_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
model_name: str,
|
||||||
|
model_class: type[AttentionQuantPatternModel],
|
||||||
|
backend: _Backend,
|
||||||
|
split_attention: bool,
|
||||||
|
use_inductor_graph_partition: bool,
|
||||||
|
monkeypatch,
|
||||||
|
dist_init,
|
||||||
|
caplog_vllm,
|
||||||
|
):
|
||||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||||
|
|
||||||
if use_inductor_graph_partition and not is_torch_equal_or_newer(
|
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
"2.9.0.dev"):
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
pytest.skip("inductor graph partition is only available "
|
|
||||||
"in PyTorch 2.9+")
|
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
if split_attention:
|
if split_attention:
|
||||||
@@ -386,21 +435,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||||
),
|
),
|
||||||
cache_config=CacheConfig(cache_dtype="fp8"))
|
cache_config=CacheConfig(cache_dtype="fp8"),
|
||||||
|
)
|
||||||
|
|
||||||
# Create test inputs
|
# Create test inputs
|
||||||
q = torch.randn(batch_size,
|
q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
|
||||||
num_qo_heads * head_size,
|
k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||||
dtype=dtype,
|
v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||||
device=device)
|
|
||||||
k = torch.randn(batch_size,
|
|
||||||
num_kv_heads * head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
v = torch.randn(batch_size,
|
|
||||||
num_kv_heads * head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
# Mark first dimension as dynamic for realistic testing
|
# Mark first dimension as dynamic for realistic testing
|
||||||
torch._dynamo.mark_dynamic(q, 0)
|
torch._dynamo.mark_dynamic(q, 0)
|
||||||
@@ -409,42 +450,53 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
|
|
||||||
# Run model directly without compilation and fusion
|
# Run model directly without compilation and fusion
|
||||||
vllm_config_unfused = copy.deepcopy(vllm_config)
|
vllm_config_unfused = copy.deepcopy(vllm_config)
|
||||||
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
|
with (
|
||||||
attn_metadata=None, vllm_config=vllm_config_unfused
|
set_current_vllm_config(vllm_config_unfused),
|
||||||
), global_force_attn_backend_context_manager(backend):
|
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
|
||||||
model_unfused = model_class(num_qo_heads=num_qo_heads,
|
global_force_attn_backend_context_manager(backend),
|
||||||
num_kv_heads=num_kv_heads,
|
):
|
||||||
head_size=head_size,
|
model_unfused = model_class(
|
||||||
kv_cache_dtype=FP8_DTYPE,
|
num_qo_heads=num_qo_heads,
|
||||||
device=device,
|
num_kv_heads=num_kv_heads,
|
||||||
vllm_config=vllm_config_unfused)
|
head_size=head_size,
|
||||||
|
kv_cache_dtype=FP8_DTYPE,
|
||||||
|
device=device,
|
||||||
|
vllm_config=vllm_config_unfused,
|
||||||
|
)
|
||||||
model_unfused = model_unfused.to(device)
|
model_unfused = model_unfused.to(device)
|
||||||
|
|
||||||
forward_ctx = get_forward_context()
|
forward_ctx = get_forward_context()
|
||||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
|
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
|
||||||
batch_size, use_hnd=split_attention)
|
batch_size, use_hnd=split_attention
|
||||||
|
)
|
||||||
|
|
||||||
# Run model directly without compilation and fusion
|
# Run model directly without compilation and fusion
|
||||||
result_unfused = model_unfused(q, k, v)
|
result_unfused = model_unfused(q, k, v)
|
||||||
|
|
||||||
# Run model with attn fusion enabled
|
# Run model with attn fusion enabled
|
||||||
vllm_config.compilation_config.pass_config = PassConfig(
|
vllm_config.compilation_config.pass_config = PassConfig(
|
||||||
enable_attn_fusion=True, enable_noop=True)
|
enable_attn_fusion=True, enable_noop=True
|
||||||
with set_current_vllm_config(vllm_config), set_forward_context(
|
)
|
||||||
attn_metadata=None, vllm_config=vllm_config
|
with (
|
||||||
), global_force_attn_backend_context_manager(backend):
|
set_current_vllm_config(vllm_config),
|
||||||
model_fused = model_class(num_qo_heads=num_qo_heads,
|
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
|
||||||
num_kv_heads=num_kv_heads,
|
global_force_attn_backend_context_manager(backend),
|
||||||
head_size=head_size,
|
):
|
||||||
kv_cache_dtype=FP8_DTYPE,
|
model_fused = model_class(
|
||||||
device=device,
|
num_qo_heads=num_qo_heads,
|
||||||
vllm_config=vllm_config,
|
num_kv_heads=num_kv_heads,
|
||||||
w=model_unfused.w)
|
head_size=head_size,
|
||||||
|
kv_cache_dtype=FP8_DTYPE,
|
||||||
|
device=device,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
w=model_unfused.w,
|
||||||
|
)
|
||||||
model_fused = model_fused.to(device)
|
model_fused = model_fused.to(device)
|
||||||
|
|
||||||
forward_ctx = get_forward_context()
|
forward_ctx = get_forward_context()
|
||||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
|
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
|
||||||
batch_size, use_hnd=split_attention)
|
batch_size, use_hnd=split_attention
|
||||||
|
)
|
||||||
|
|
||||||
# Create test backend with fusion passes enabled
|
# Create test backend with fusion passes enabled
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
@@ -454,9 +506,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
||||||
|
|
||||||
# Compile model with fusion enabled
|
# Compile model with fusion enabled
|
||||||
model_compiled = torch.compile(model_fused,
|
model_compiled = torch.compile(
|
||||||
backend=test_backend,
|
model_fused, backend=test_backend, fullgraph=True
|
||||||
fullgraph=True)
|
)
|
||||||
assert model_compiled.attn._o_scale_float is None
|
assert model_compiled.attn._o_scale_float is None
|
||||||
|
|
||||||
result_fused_1 = model_compiled(q, k, v)
|
result_fused_1 = model_compiled(q, k, v)
|
||||||
@@ -471,49 +523,49 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
|
|
||||||
assert model_compiled.attn._o_scale_float is not None
|
assert model_compiled.attn._o_scale_float is not None
|
||||||
|
|
||||||
torch.testing.assert_close(result_unfused,
|
torch.testing.assert_close(
|
||||||
result_fused_2,
|
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
|
||||||
atol=1e-2,
|
)
|
||||||
rtol=1e-2)
|
|
||||||
|
|
||||||
# Check attn fusion support
|
# Check attn fusion support
|
||||||
quant_key = model_class.quant_key
|
quant_key = model_class.quant_key
|
||||||
attn_fusion_supported = [
|
attn_fusion_supported = [
|
||||||
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
|
layer.impl.fused_output_quant_supported(quant_key)
|
||||||
vllm_config.compilation_config.static_forward_context.items()
|
for key, layer in vllm_config.compilation_config.static_forward_context.items()
|
||||||
]
|
]
|
||||||
if any(attn_fusion_supported):
|
if any(attn_fusion_supported):
|
||||||
# Check quantization ops in the graph before and after fusion
|
# Check quantization ops in the graph before and after fusion
|
||||||
test_backend.check_before_ops([QUANT_OPS[quant_key]],
|
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
|
||||||
fully_replaced=True)
|
|
||||||
|
|
||||||
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
||||||
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
||||||
|
|
||||||
# Check attention ops in the graph before and after fusion
|
# Check attention ops in the graph before and after fusion
|
||||||
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
|
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
|
||||||
attn_nodes_post = list(find_op_nodes(ATTN_OP,
|
attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))
|
||||||
test_backend.graph_post_pass))
|
|
||||||
|
|
||||||
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
|
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
|
||||||
assert len(attn_nodes_pre) == len(attn_nodes_post), \
|
assert len(attn_nodes_pre) == len(attn_nodes_post), (
|
||||||
"Should have same number of attention nodes before and after fusion"
|
"Should have same number of attention nodes before and after fusion"
|
||||||
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
|
)
|
||||||
|
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
|
||||||
"Attention should not have output_scale before fusion"
|
"Attention should not have output_scale before fusion"
|
||||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
|
)
|
||||||
|
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
|
||||||
"Attention should have output_scale after fusion"
|
"Attention should have output_scale after fusion"
|
||||||
|
)
|
||||||
|
|
||||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
|
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale before fusion"
|
"Attention should not have output_block_scale before fusion"
|
||||||
|
)
|
||||||
if quant_key.dtype == FP8_DTYPE:
|
if quant_key.dtype == FP8_DTYPE:
|
||||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale after FP8 fusion"
|
"Attention should not have output_block_scale after FP8 fusion"
|
||||||
|
)
|
||||||
elif quant_key.dtype == FP4_DTYPE:
|
elif quant_key.dtype == FP4_DTYPE:
|
||||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
|
||||||
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
|
"Attention should have output_block_scale after FP4 fusion"
|
||||||
|
) # noqa: E501
|
||||||
|
|
||||||
# Check that results are close
|
# Check that results are close
|
||||||
torch.testing.assert_close(result_unfused,
|
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)
|
||||||
result_fused_1,
|
|
||||||
atol=1e-2,
|
|
||||||
rtol=1e-2)
|
|
||||||
|
|||||||
@@ -6,14 +6,12 @@ import torch
|
|||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
|
||||||
VllmConfig)
|
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype",
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||||
[torch.float16, torch.bfloat16, torch.float32])
|
|
||||||
@pytest.mark.parametrize("num_tokens", [256, 1024])
|
@pytest.mark.parametrize("num_tokens", [256, 1024])
|
||||||
@pytest.mark.parametrize("hidden_size", [64, 4096])
|
@pytest.mark.parametrize("hidden_size", [64, 4096])
|
||||||
def test_noop_elimination(dtype, num_tokens, hidden_size):
|
def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||||
@@ -22,7 +20,6 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
|||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Chain of reshapes
|
# Chain of reshapes
|
||||||
y = x.reshape(-1, 128, 32)
|
y = x.reshape(-1, 128, 32)
|
||||||
@@ -32,7 +29,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
|||||||
# Final reshape that should remain
|
# Final reshape that should remain
|
||||||
b = a.reshape(-1, 128, 32)
|
b = a.reshape(-1, 128, 32)
|
||||||
# No-op slice
|
# No-op slice
|
||||||
c = b[0:b.shape[0]]
|
c = b[0 : b.shape[0]]
|
||||||
# The pass should replace the result of this op with `c`.
|
# The pass should replace the result of this op with `c`.
|
||||||
d = torch.slice_scatter(
|
d = torch.slice_scatter(
|
||||||
torch.ones_like(c), # Dummy tensor to be scattered into
|
torch.ones_like(c), # Dummy tensor to be scattered into
|
||||||
@@ -43,10 +40,12 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
|||||||
)
|
)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
pass_config=PassConfig(enable_noop=True),
|
level=CompilationLevel.PIECEWISE,
|
||||||
))
|
pass_config=PassConfig(enable_noop=True),
|
||||||
|
)
|
||||||
|
)
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
|
||||||
@@ -82,17 +81,18 @@ def test_non_noop_slice_preserved():
|
|||||||
x = torch.randn(16, 16)
|
x = torch.randn(16, 16)
|
||||||
|
|
||||||
class SliceModel(torch.nn.Module):
|
class SliceModel(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
base = x.clone()
|
base = x.clone()
|
||||||
src = torch.ones(15, 16)
|
src = torch.ones(15, 16)
|
||||||
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
||||||
return x[0:-1, :], y
|
return x[0:-1, :], y
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
compilation_config=CompilationConfig(
|
||||||
pass_config=PassConfig(enable_noop=True),
|
level=CompilationLevel.PIECEWISE,
|
||||||
))
|
pass_config=PassConfig(enable_noop=True),
|
||||||
|
)
|
||||||
|
)
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
backend = TestBackend(noop_pass)
|
backend = TestBackend(noop_pass)
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ def test_bad_callable():
|
|||||||
|
|
||||||
# Pass that inherits from InductorPass
|
# Pass that inherits from InductorPass
|
||||||
class ProperPass(InductorPass):
|
class ProperPass(InductorPass):
|
||||||
|
|
||||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -39,8 +38,7 @@ class ProperPass(InductorPass):
|
|||||||
ProperPass(),
|
ProperPass(),
|
||||||
# Can also wrap callables in CallableInductorPass for compliance
|
# Can also wrap callables in CallableInductorPass for compliance
|
||||||
CallableInductorPass(simple_callable),
|
CallableInductorPass(simple_callable),
|
||||||
CallableInductorPass(simple_callable,
|
CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
|
||||||
InductorPass.hash_source(__file__))
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_pass_manager_uuid(callable):
|
def test_pass_manager_uuid(callable):
|
||||||
@@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable):
|
|||||||
|
|
||||||
# UUID should be different due to config change
|
# UUID should be different due to config change
|
||||||
config2 = copy.deepcopy(config)
|
config2 = copy.deepcopy(config)
|
||||||
config2.compilation_config.pass_config.enable_fusion = not \
|
config2.compilation_config.pass_config.enable_fusion = (
|
||||||
config2.compilation_config.pass_config.enable_fusion
|
not config2.compilation_config.pass_config.enable_fusion
|
||||||
|
)
|
||||||
pass_manager3 = PostGradPassManager()
|
pass_manager3 = PostGradPassManager()
|
||||||
pass_manager3.configure(config2)
|
pass_manager3.configure(config2)
|
||||||
pass_manager3.add(callable)
|
pass_manager3.add(callable)
|
||||||
|
|||||||
@@ -12,14 +12,20 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
|
|||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (
|
||||||
PassConfig, VllmConfig)
|
CompilationConfig,
|
||||||
|
DeviceConfig,
|
||||||
|
ModelConfig,
|
||||||
|
PassConfig,
|
||||||
|
VllmConfig,
|
||||||
|
)
|
||||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
from vllm.distributed.parallel_state import (
|
||||||
initialize_model_parallel)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||||
Fp8LinearOp)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
@@ -36,16 +42,15 @@ prompts = [
|
|||||||
|
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||||
hidden_size=16,
|
):
|
||||||
intermediate_size=32,
|
|
||||||
vllm_config: VllmConfig = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.gate_proj = torch.nn.Parameter(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
torch.empty((intermediate_size, hidden_size)))
|
torch.empty((intermediate_size, hidden_size))
|
||||||
|
)
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
@@ -53,18 +58,18 @@ class TestModel(torch.nn.Module):
|
|||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the operations in the FX graph
|
Forward pass implementing the operations in the FX graph
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hidden_states: Input tensor
|
hidden_states: Input tensor
|
||||||
residual: Residual tensor from previous layer
|
residual: Residual tensor from previous layer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing the output tensor
|
Tuple containing the output tensor
|
||||||
"""
|
"""
|
||||||
# Reshape input
|
# Reshape input
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
|
|
||||||
#matrix multiplication
|
# matrix multiplication
|
||||||
permute = self.gate_proj.permute(1, 0)
|
permute = self.gate_proj.permute(1, 0)
|
||||||
mm = torch.mm(view, permute)
|
mm = torch.mm(view, permute)
|
||||||
|
|
||||||
@@ -82,7 +87,7 @@ class TestModel(torch.nn.Module):
|
|||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.reduce_scatter.default,
|
torch.ops.vllm.reduce_scatter.default,
|
||||||
torch.ops.vllm.all_gather.default
|
torch.ops.vllm.all_gather.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
def ops_in_model(self):
|
def ops_in_model(self):
|
||||||
@@ -90,18 +95,16 @@ class TestModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestQuantModel(torch.nn.Module):
|
class TestQuantModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||||
hidden_size=16,
|
):
|
||||||
intermediate_size=32,
|
|
||||||
vllm_config: VllmConfig = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
self.gate_proj = torch.nn.Parameter(
|
||||||
(intermediate_size, hidden_size)),
|
torch.empty((intermediate_size, hidden_size)), requires_grad=False
|
||||||
requires_grad=False)
|
)
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
@@ -111,25 +114,24 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
# Create a weight that is compatible with torch._scaled_mm,
|
# Create a weight that is compatible with torch._scaled_mm,
|
||||||
# which expects a column-major layout.
|
# which expects a column-major layout.
|
||||||
self.w = torch.rand(hidden_size,
|
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||||
intermediate_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
"""
|
"""
|
||||||
Forward pass implementing the operations in the FX graph
|
Forward pass implementing the operations in the FX graph
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hidden_states: Input tensor
|
hidden_states: Input tensor
|
||||||
residual: Residual tensor from previous layer
|
residual: Residual tensor from previous layer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing the output tensor
|
Tuple containing the output tensor
|
||||||
"""
|
"""
|
||||||
# Reshape input
|
# Reshape input
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
|
|
||||||
#matrix multiplication
|
# matrix multiplication
|
||||||
permute = self.gate_proj.permute(1, 0)
|
permute = self.gate_proj.permute(1, 0)
|
||||||
mm = torch.mm(view, permute)
|
mm = torch.mm(view, permute)
|
||||||
|
|
||||||
@@ -140,45 +142,51 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
|
|
||||||
# scaled_mm with static input quantization
|
# scaled_mm with static input quantization
|
||||||
fp8_linear_result = self.fp8_linear.apply(norm_output,
|
fp8_linear_result = self.fp8_linear.apply(
|
||||||
self.w,
|
norm_output,
|
||||||
self.wscale,
|
self.w,
|
||||||
input_scale=self.scale.to(
|
self.wscale,
|
||||||
norm_output.device))
|
input_scale=self.scale.to(norm_output.device),
|
||||||
|
)
|
||||||
|
|
||||||
return fp8_linear_result, residual_output
|
return fp8_linear_result, residual_output
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
ops_to_remove = [torch.ops.vllm.all_reduce.default
|
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
|
||||||
] # Always removed by SP
|
|
||||||
# The following are only removed if fusion happens
|
# The following are only removed if fusion happens
|
||||||
if self.vllm_config and self.vllm_config.compilation_config \
|
if (
|
||||||
.pass_config.enable_fusion:
|
self.vllm_config
|
||||||
ops_to_remove.extend([
|
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||||
torch.ops._C.fused_add_rms_norm.default,
|
):
|
||||||
torch.ops._C.static_scaled_fp8_quant.default,
|
ops_to_remove.extend(
|
||||||
])
|
[
|
||||||
|
torch.ops._C.fused_add_rms_norm.default,
|
||||||
|
torch.ops._C.static_scaled_fp8_quant.default,
|
||||||
|
]
|
||||||
|
)
|
||||||
return ops_to_remove
|
return ops_to_remove
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
ops_to_add = [
|
ops_to_add = [
|
||||||
torch.ops.vllm.reduce_scatter.default,
|
torch.ops.vllm.reduce_scatter.default,
|
||||||
torch.ops.vllm.all_gather.default
|
torch.ops.vllm.all_gather.default,
|
||||||
]
|
]
|
||||||
# The following is only added if fusion happens
|
# The following is only added if fusion happens
|
||||||
if self.vllm_config and self.vllm_config.compilation_config \
|
if (
|
||||||
.pass_config.enable_fusion:
|
self.vllm_config
|
||||||
ops_to_add.append(
|
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
|
):
|
||||||
|
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
|
||||||
return ops_to_add
|
return ops_to_add
|
||||||
|
|
||||||
def ops_in_model(self):
|
def ops_in_model(self):
|
||||||
if self.vllm_config and self.vllm_config.compilation_config \
|
if (
|
||||||
.pass_config.enable_fusion:
|
self.vllm_config
|
||||||
|
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||||
|
):
|
||||||
# If fusion happens, the fused op is the one
|
# If fusion happens, the fused op is the one
|
||||||
# we check for (de)functionalization
|
# we check for (de)functionalization
|
||||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
|
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] # noqa: E501
|
||||||
] # noqa: E501
|
|
||||||
else:
|
else:
|
||||||
# If no fusion, the original ops are checked
|
# If no fusion, the original ops are checked
|
||||||
return [
|
return [
|
||||||
@@ -195,30 +203,47 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
def test_sequence_parallelism_pass(
|
||||||
def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
|
test_model_cls: type[torch.nn.Module],
|
||||||
batch_size: int, seq_len: int,
|
batch_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype,
|
seq_len: int,
|
||||||
enable_fusion: bool):
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
enable_fusion: bool,
|
||||||
|
):
|
||||||
num_processes = 2
|
num_processes = 2
|
||||||
|
|
||||||
def run_torch_spawn(fn, nprocs):
|
def run_torch_spawn(fn, nprocs):
|
||||||
# need to use torch.mp.spawn otherwise will have problems with
|
# need to use torch.mp.spawn otherwise will have problems with
|
||||||
# torch.distributed and cuda
|
# torch.distributed and cuda
|
||||||
torch.multiprocessing.spawn(fn,
|
torch.multiprocessing.spawn(
|
||||||
args=(num_processes, test_model_cls,
|
fn,
|
||||||
batch_size, seq_len, hidden_size,
|
args=(
|
||||||
dtype, enable_fusion),
|
num_processes,
|
||||||
nprocs=nprocs)
|
test_model_cls,
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
hidden_size,
|
||||||
|
dtype,
|
||||||
|
enable_fusion,
|
||||||
|
),
|
||||||
|
nprocs=nprocs,
|
||||||
|
)
|
||||||
|
|
||||||
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
||||||
|
|
||||||
|
|
||||||
def sequence_parallelism_pass_on_test_model(
|
def sequence_parallelism_pass_on_test_model(
|
||||||
local_rank: int, world_size: int,
|
local_rank: int,
|
||||||
test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
|
world_size: int,
|
||||||
hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
|
test_model_cls: type[torch.nn.Module],
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
enable_fusion: bool,
|
||||||
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
@@ -226,13 +251,15 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# initialize distributed
|
# initialize distributed
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -240,27 +267,28 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
|
|
||||||
# configure vllm config for SequenceParallelismPass
|
# configure vllm config for SequenceParallelismPass
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
vllm_config.compilation_config = CompilationConfig(
|
||||||
enable_sequence_parallelism=True,
|
pass_config=PassConfig(
|
||||||
enable_fusion=enable_fusion,
|
enable_sequence_parallelism=True,
|
||||||
enable_noop=True)) # NoOp needed for fusion
|
enable_fusion=enable_fusion,
|
||||||
|
enable_noop=True,
|
||||||
|
)
|
||||||
|
) # NoOp needed for fusion
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
# in the vllm_config, it's not really used.
|
# in the vllm_config, it's not really used.
|
||||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||||
vllm_config.model_config = ModelConfig(model=model_name,
|
vllm_config.model_config = ModelConfig(
|
||||||
trust_remote_code=True,
|
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||||
dtype=dtype,
|
)
|
||||||
seed=42)
|
|
||||||
|
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
passes_for_backend: list[VllmInductorPass] = \
|
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
|
||||||
[noop_pass, sequence_parallelism_pass]
|
|
||||||
|
|
||||||
if enable_fusion:
|
if enable_fusion:
|
||||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||||
@@ -271,12 +299,9 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
backend_no_func = TestBackend(*passes_for_backend)
|
backend_no_func = TestBackend(*passes_for_backend)
|
||||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||||
|
|
||||||
model = test_model_cls(hidden_size,
|
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
|
||||||
hidden_size * 2,
|
|
||||||
vllm_config=vllm_config)
|
|
||||||
|
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
dtype=dtype)
|
|
||||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
|
|
||||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||||
@@ -297,8 +322,7 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
# check if the functionalization pass is applied
|
# check if the functionalization pass is applied
|
||||||
for op in model.ops_in_model():
|
for op in model.ops_in_model():
|
||||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
|
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
|
||||||
op) is None # noqa: E501
|
|
||||||
|
|
||||||
# make sure the ops were all de-functionalized
|
# make sure the ops were all de-functionalized
|
||||||
found = dict()
|
found = dict()
|
||||||
|
|||||||
@@ -8,10 +8,15 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.compilation.activation_quant_fusion import (
|
from vllm.compilation.activation_quant_fusion import (
|
||||||
FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
|
FUSED_OPS,
|
||||||
|
SILU_MUL_OP,
|
||||||
|
ActivationQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.compilation.fusion import QUANT_OPS
|
from vllm.compilation.fusion import QUANT_OPS
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
@@ -19,9 +24,14 @@ from vllm.compilation.post_cleanup import PostCleanupPass
|
|||||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
|
GroupShape,
|
||||||
|
kFp8StaticTensorSym,
|
||||||
|
kNvfp4Quant,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp, cutlass_fp8_supported)
|
Fp8LinearOp,
|
||||||
|
cutlass_fp8_supported,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import override_cutlass_fp8_supported
|
from ..utils import override_cutlass_fp8_supported
|
||||||
@@ -36,7 +46,6 @@ def is_nvfp4_supported():
|
|||||||
|
|
||||||
|
|
||||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
@@ -53,10 +62,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
x2 = self.fp8_linear.apply(y,
|
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||||
self.w,
|
|
||||||
self.wscale,
|
|
||||||
input_scale=self.wscale)
|
|
||||||
return x2
|
return x2
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -67,11 +73,12 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from vllm.compilation.activation_quant_fusion import (
|
from vllm.compilation.activation_quant_fusion import (
|
||||||
silu_and_mul_nvfp4_quant_supported)
|
silu_and_mul_nvfp4_quant_supported,
|
||||||
|
)
|
||||||
|
|
||||||
assert silu_and_mul_nvfp4_quant_supported
|
assert silu_and_mul_nvfp4_quant_supported
|
||||||
|
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
@@ -88,12 +95,14 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
||||||
out = cutlass_scaled_fp4_mm(a=y_quant,
|
out = cutlass_scaled_fp4_mm(
|
||||||
b=self.w,
|
a=y_quant,
|
||||||
block_scale_a=y_block_scale,
|
b=self.w,
|
||||||
block_scale_b=self.w_block_scale,
|
block_scale_a=y_block_scale,
|
||||||
alpha=self.alpha,
|
block_scale_b=self.w_block_scale,
|
||||||
out_dtype=y.dtype)
|
alpha=self.alpha,
|
||||||
|
out_dtype=y.dtype,
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
@@ -108,16 +117,24 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_class",
|
"model_class",
|
||||||
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
cast(
|
||||||
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]))
|
list[type],
|
||||||
|
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||||
|
if is_nvfp4_supported()
|
||||||
|
else [TestSiluMulFp8QuantModel],
|
||||||
|
),
|
||||||
|
)
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize("cuda_force_torch",
|
@pytest.mark.parametrize(
|
||||||
[True, False] if cutlass_fp8_supported() else [True])
|
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
)
|
||||||
reason="Only test on CUDA and ROCm")
|
@pytest.mark.skipif(
|
||||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
|
||||||
cuda_force_torch):
|
)
|
||||||
|
def test_fusion_silu_and_mul_quant(
|
||||||
|
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
|
||||||
|
):
|
||||||
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
||||||
pytest.skip("Duplicate tests for NVFP4")
|
pytest.skip("Duplicate tests for NVFP4")
|
||||||
|
|
||||||
@@ -129,17 +146,13 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
|||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
config = VllmConfig()
|
config = VllmConfig()
|
||||||
config.compilation_config = CompilationConfig(
|
config.compilation_config = CompilationConfig(
|
||||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
|
||||||
|
)
|
||||||
fusion_pass = ActivationQuantFusionPass(config)
|
fusion_pass = ActivationQuantFusionPass(config)
|
||||||
|
|
||||||
passes = [
|
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
|
||||||
NoOpEliminationPass(config), fusion_pass,
|
|
||||||
PostCleanupPass(config)
|
|
||||||
]
|
|
||||||
backend = TestBackend(*passes)
|
backend = TestBackend(*passes)
|
||||||
model = model_class(hidden_size=hidden_size,
|
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
|
||||||
cuda_force_torch=cuda_force_torch,
|
|
||||||
x=x)
|
|
||||||
|
|
||||||
# First dimension dynamic
|
# First dimension dynamic
|
||||||
torch._dynamo.mark_dynamic(x, 0)
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
@@ -155,10 +168,9 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
|||||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||||
atol, rtol = 1e-1, 1e-1
|
atol, rtol = 1e-1, 1e-1
|
||||||
|
|
||||||
torch.testing.assert_close(result[0].to(dtype=dtype),
|
torch.testing.assert_close(
|
||||||
result2[0].to(dtype=dtype),
|
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||||
atol=atol,
|
)
|
||||||
rtol=rtol)
|
|
||||||
|
|
||||||
assert fusion_pass.matched_count == 1
|
assert fusion_pass.matched_count == 1
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from vllm.config import CompilationLevel
|
|||||||
|
|
||||||
|
|
||||||
class MyMod(torch.nn.Module):
|
class MyMod(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
return x + cache
|
return x + cache
|
||||||
@@ -18,12 +17,12 @@ class MyMod(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
|
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||||
|
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
compiled_callable = torch.compile(self.forward, backend="eager")
|
compiled_callable = torch.compile(self.forward, backend="eager")
|
||||||
super().__init__(compiled_callable,
|
super().__init__(
|
||||||
compilation_level=CompilationLevel.DYNAMO_ONCE)
|
compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
# this is the function to be compiled
|
# this is the function to be compiled
|
||||||
@@ -54,10 +53,8 @@ def test_torch_compile_wrapper():
|
|||||||
|
|
||||||
# for new input, dispatch to the compiled code directly
|
# for new input, dispatch to the compiled code directly
|
||||||
new_x = torch.tensor([3])
|
new_x = torch.tensor([3])
|
||||||
assert wrapper(new_x,
|
assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code
|
||||||
None).item() == 6 # dispatch to the first compiled code
|
assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code
|
||||||
assert wrapper(
|
|
||||||
new_x, cache).item() == 5 # dispatch to the second compiled code
|
|
||||||
|
|
||||||
for wrapper in wrappers:
|
for wrapper in wrappers:
|
||||||
# make sure they have independent compiled codes
|
# make sure they have independent compiled codes
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def create_config():
|
def create_config():
|
||||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
engine_args = EngineArgs(
|
||||||
trust_remote_code=True)
|
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||||
|
)
|
||||||
return engine_args.create_engine_config()
|
return engine_args.create_engine_config()
|
||||||
|
|
||||||
# Create config with CUDA_VISIBLE_DEVICES set normally
|
# Create config with CUDA_VISIBLE_DEVICES set normally
|
||||||
@@ -34,16 +35,18 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
|||||||
empty_config_dict.pop("instance_id", None)
|
empty_config_dict.pop("instance_id", None)
|
||||||
|
|
||||||
assert deep_compare(normal_config_dict, empty_config_dict), (
|
assert deep_compare(normal_config_dict, empty_config_dict), (
|
||||||
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
|
'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""'
|
||||||
" should be equivalent")
|
" should be equivalent"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
||||||
# In testing, this method needs to be nested inside as ray does not
|
# In testing, this method needs to be nested inside as ray does not
|
||||||
# see the test module.
|
# see the test module.
|
||||||
def create_config():
|
def create_config():
|
||||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
engine_args = EngineArgs(
|
||||||
trust_remote_code=True)
|
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||||
|
)
|
||||||
return engine_args.create_engine_config()
|
return engine_args.create_engine_config()
|
||||||
|
|
||||||
config = create_config()
|
config = create_config()
|
||||||
@@ -51,6 +54,7 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
assert parallel_config.ray_runtime_env is None
|
assert parallel_config.ray_runtime_env is None
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
ray.init()
|
ray.init()
|
||||||
|
|
||||||
runtime_env = {
|
runtime_env = {
|
||||||
@@ -59,13 +63,13 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
config_ref = ray.remote(create_config).options(
|
config_ref = ray.remote(create_config).options(runtime_env=runtime_env).remote()
|
||||||
runtime_env=runtime_env).remote()
|
|
||||||
|
|
||||||
config = ray.get(config_ref)
|
config = ray.get(config_ref)
|
||||||
parallel_config = config.parallel_config
|
parallel_config = config.parallel_config
|
||||||
assert parallel_config.ray_runtime_env is not None
|
assert parallel_config.ray_runtime_env is not None
|
||||||
assert parallel_config.ray_runtime_env.env_vars().get(
|
assert (
|
||||||
"TEST_ENV_VAR") == "test_value"
|
parallel_config.ray_runtime_env.env_vars().get("TEST_ENV_VAR") == "test_value"
|
||||||
|
)
|
||||||
|
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|||||||
@@ -16,13 +16,13 @@ def test_mp_reducer(monkeypatch):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
|
# Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
|
||||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
# Ensure transformers_modules is not in sys.modules
|
# Ensure transformers_modules is not in sys.modules
|
||||||
if 'transformers_modules' in sys.modules:
|
if "transformers_modules" in sys.modules:
|
||||||
del sys.modules['transformers_modules']
|
del sys.modules["transformers_modules"]
|
||||||
|
|
||||||
with patch('multiprocessing.reducer.register') as mock_register:
|
with patch("multiprocessing.reducer.register") as mock_register:
|
||||||
engine_args = AsyncEngineArgs(
|
engine_args = AsyncEngineArgs(
|
||||||
model="facebook/opt-125m",
|
model="facebook/opt-125m",
|
||||||
max_model_len=32,
|
max_model_len=32,
|
||||||
@@ -36,7 +36,8 @@ def test_mp_reducer(monkeypatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert mock_register.called, (
|
assert mock_register.called, (
|
||||||
"multiprocessing.reducer.register should have been called")
|
"multiprocessing.reducer.register should have been called"
|
||||||
|
)
|
||||||
|
|
||||||
vllm_config_registered = False
|
vllm_config_registered = False
|
||||||
for call_args in mock_register.call_args_list:
|
for call_args in mock_register.call_args_list:
|
||||||
@@ -45,8 +46,7 @@ def test_mp_reducer(monkeypatch):
|
|||||||
vllm_config_registered = True
|
vllm_config_registered = True
|
||||||
|
|
||||||
reducer_func = call_args[0][1]
|
reducer_func = call_args[0][1]
|
||||||
assert callable(
|
assert callable(reducer_func), "Reducer function should be callable"
|
||||||
reducer_func), "Reducer function should be callable"
|
|
||||||
break
|
break
|
||||||
|
|
||||||
assert vllm_config_registered, (
|
assert vllm_config_registered, (
|
||||||
|
|||||||
@@ -30,22 +30,27 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (
|
||||||
BatchEncoding, BatchFeature)
|
AutoConfig,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BatchEncoding,
|
||||||
|
BatchFeature,
|
||||||
|
)
|
||||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||||
|
|
||||||
from tests.models.utils import (TokensTextLogprobs,
|
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
|
||||||
TokensTextLogprobsPromptLogprobs)
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.assets.audio import AudioAsset
|
from vllm.assets.audio import AudioAsset
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
from vllm.assets.video import VideoAsset
|
from vllm.assets.video import VideoAsset
|
||||||
from vllm.config.model import (ConvertOption, RunnerOption,
|
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
|
||||||
_get_and_verify_dtype)
|
|
||||||
from vllm.connections import global_http_connection
|
from vllm.connections import global_http_connection
|
||||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
from vllm.distributed import (
|
||||||
init_distributed_environment,
|
cleanup_dist_env_and_memory,
|
||||||
initialize_model_parallel)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
from vllm.multimodal.utils import fetch_image
|
from vllm.multimodal.utils import fetch_image
|
||||||
@@ -82,12 +87,13 @@ class ImageAssetPrompts(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class ImageTestAssets(list[ImageAsset]):
|
class ImageTestAssets(list[ImageAsset]):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__([
|
super().__init__(
|
||||||
ImageAsset("stop_sign"),
|
[
|
||||||
ImageAsset("cherry_blossom"),
|
ImageAsset("stop_sign"),
|
||||||
])
|
ImageAsset("cherry_blossom"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
|
def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@@ -104,11 +110,12 @@ class VideoAssetPrompts(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class VideoTestAssets(list[VideoAsset]):
|
class VideoTestAssets(list[VideoAsset]):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__([
|
super().__init__(
|
||||||
VideoAsset("baby_reading"),
|
[
|
||||||
])
|
VideoAsset("baby_reading"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
|
def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
|
||||||
return [prompts["baby_reading"]]
|
return [prompts["baby_reading"]]
|
||||||
@@ -120,12 +127,13 @@ class AudioAssetPrompts(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class AudioTestAssets(list[AudioAsset]):
|
class AudioTestAssets(list[AudioAsset]):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__([
|
super().__init__(
|
||||||
AudioAsset("mary_had_lamb"),
|
[
|
||||||
AudioAsset("winning_call"),
|
AudioAsset("mary_had_lamb"),
|
||||||
])
|
AudioAsset("winning_call"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
|
def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
|
||||||
return [prompts["mary_had_lamb"], prompts["winning_call"]]
|
return [prompts["mary_had_lamb"], prompts["winning_call"]]
|
||||||
@@ -220,6 +228,7 @@ def example_system_message() -> str:
|
|||||||
|
|
||||||
class DecoderPromptType(Enum):
|
class DecoderPromptType(Enum):
|
||||||
"""For encoder/decoder models only."""
|
"""For encoder/decoder models only."""
|
||||||
|
|
||||||
CUSTOM = 1
|
CUSTOM = 1
|
||||||
NONE = 2
|
NONE = 2
|
||||||
EMPTY_STR = 3
|
EMPTY_STR = 3
|
||||||
@@ -253,15 +262,13 @@ _R = TypeVar("_R")
|
|||||||
|
|
||||||
|
|
||||||
class HfRunner:
|
class HfRunner:
|
||||||
|
|
||||||
def get_default_device(self):
|
def get_default_device(self):
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
return ("cpu"
|
return "cpu" if current_platform.is_cpu() else current_platform.device_type
|
||||||
if current_platform.is_cpu() else current_platform.device_type)
|
|
||||||
|
|
||||||
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
|
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
|
||||||
if x is None or isinstance(x, (bool, )):
|
if x is None or isinstance(x, (bool,)):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
@@ -289,8 +296,11 @@ class HfRunner:
|
|||||||
# Set this to avoid hanging issue
|
# Set this to avoid hanging issue
|
||||||
default_torch_num_threads: Optional[int] = None,
|
default_torch_num_threads: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
init_ctx = (
|
||||||
set_default_torch_num_threads(default_torch_num_threads))
|
nullcontext()
|
||||||
|
if default_torch_num_threads is None
|
||||||
|
else set_default_torch_num_threads(default_torch_num_threads)
|
||||||
|
)
|
||||||
|
|
||||||
with init_ctx:
|
with init_ctx:
|
||||||
self._init(
|
self._init(
|
||||||
@@ -362,14 +372,15 @@ class HfRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# in case some unquantized custom models are not in same dtype
|
# in case some unquantized custom models are not in same dtype
|
||||||
if (getattr(model, "quantization_method", None) is None
|
if getattr(model, "quantization_method", None) is None and any(
|
||||||
and any(p.dtype != self.dtype
|
p.dtype != self.dtype for p in model.parameters()
|
||||||
for p in model.parameters())):
|
):
|
||||||
model = model.to(dtype=self.dtype)
|
model = model.to(dtype=self.dtype)
|
||||||
|
|
||||||
if (getattr(model, "quantization_method", None) != "bitsandbytes"
|
if (
|
||||||
and len({p.device
|
getattr(model, "quantization_method", None) != "bitsandbytes"
|
||||||
for p in model.parameters()}) < 2):
|
and len({p.device for p in model.parameters()}) < 2
|
||||||
|
):
|
||||||
model = model.to(device=self.device)
|
model = model.to(device=self.device)
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -384,6 +395,7 @@ class HfRunner:
|
|||||||
# don't put this import at the top level
|
# don't put this import at the top level
|
||||||
# it will call torch.cuda.device_count()
|
# it will call torch.cuda.device_count()
|
||||||
from transformers import AutoProcessor # noqa: F401
|
from transformers import AutoProcessor # noqa: F401
|
||||||
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -471,10 +483,9 @@ class HfRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(
|
||||||
images=images,
|
prompts, images=images, videos=videos, audios=audios
|
||||||
videos=videos,
|
)
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
@@ -501,16 +512,17 @@ class HfRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[int], str]]:
|
) -> list[tuple[list[int], str]]:
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(
|
||||||
do_sample=False,
|
prompts,
|
||||||
max_new_tokens=max_tokens,
|
do_sample=False,
|
||||||
images=images,
|
max_new_tokens=max_tokens,
|
||||||
videos=videos,
|
images=images,
|
||||||
audios=audios,
|
videos=videos,
|
||||||
**kwargs)
|
audios=audios,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
return [(output_ids[0], output_str[0])
|
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
|
||||||
for output_ids, output_str in outputs]
|
|
||||||
|
|
||||||
def generate_beam_search(
|
def generate_beam_search(
|
||||||
self,
|
self,
|
||||||
@@ -521,21 +533,22 @@ class HfRunner:
|
|||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(
|
||||||
do_sample=False,
|
prompts,
|
||||||
max_new_tokens=max_tokens,
|
do_sample=False,
|
||||||
num_beams=beam_width,
|
max_new_tokens=max_tokens,
|
||||||
num_return_sequences=beam_width,
|
num_beams=beam_width,
|
||||||
images=images,
|
num_return_sequences=beam_width,
|
||||||
videos=videos,
|
images=images,
|
||||||
audios=audios)
|
videos=videos,
|
||||||
|
audios=audios,
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(len(outputs)):
|
for i in range(len(outputs)):
|
||||||
output_ids, output_str = outputs[i]
|
output_ids, output_str = outputs[i]
|
||||||
for j in range(len(output_ids)):
|
for j in range(len(output_ids)):
|
||||||
output_ids[j] = [
|
output_ids[j] = [
|
||||||
x for x in output_ids[j]
|
x for x in output_ids[j] if x != self.tokenizer.pad_token_id
|
||||||
if x != self.tokenizer.pad_token_id
|
|
||||||
]
|
]
|
||||||
outputs[i] = (output_ids, output_str)
|
outputs[i] = (output_ids, output_str)
|
||||||
return outputs
|
return outputs
|
||||||
@@ -549,10 +562,9 @@ class HfRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[torch.Tensor]]:
|
) -> list[list[torch.Tensor]]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(
|
||||||
images=images,
|
prompts, images=images, videos=videos, audios=audios
|
||||||
videos=videos,
|
)
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
all_logprobs: list[list[torch.Tensor]] = []
|
all_logprobs: list[list[torch.Tensor]] = []
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
@@ -565,8 +577,7 @@ class HfRunner:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
seq_logprobs = self._hidden_states_to_seq_logprobs(
|
seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states)
|
||||||
output.hidden_states)
|
|
||||||
all_logprobs.append(seq_logprobs)
|
all_logprobs.append(seq_logprobs)
|
||||||
return all_logprobs
|
return all_logprobs
|
||||||
|
|
||||||
@@ -630,10 +641,9 @@ class HfRunner:
|
|||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[TokensTextLogprobs]:
|
) -> list[TokensTextLogprobs]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(
|
||||||
images=images,
|
prompts, images=images, videos=videos, audios=audios
|
||||||
videos=videos,
|
)
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
all_logprobs: list[list[dict[int, float]]] = []
|
all_logprobs: list[list[dict[int, float]]] = []
|
||||||
all_output_ids: list[list[int]] = []
|
all_output_ids: list[list[int]] = []
|
||||||
@@ -653,8 +663,7 @@ class HfRunner:
|
|||||||
(
|
(
|
||||||
seq_logprobs_lst,
|
seq_logprobs_lst,
|
||||||
output_len,
|
output_len,
|
||||||
) = self._hidden_states_to_logprobs(output.hidden_states,
|
) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)
|
||||||
num_logprobs)
|
|
||||||
|
|
||||||
all_logprobs.append(seq_logprobs_lst)
|
all_logprobs.append(seq_logprobs_lst)
|
||||||
seq_ids = output.sequences[0]
|
seq_ids = output.sequences[0]
|
||||||
@@ -664,19 +673,16 @@ class HfRunner:
|
|||||||
all_output_strs.append(self.tokenizer.decode(output_ids))
|
all_output_strs.append(self.tokenizer.decode(output_ids))
|
||||||
|
|
||||||
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
|
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
|
||||||
return [(output_ids, output_str, output_logprobs)
|
return [
|
||||||
for output_ids, output_str, output_logprobs in outputs]
|
(output_ids, output_str, output_logprobs)
|
||||||
|
for output_ids, output_str, output_logprobs in outputs
|
||||||
|
]
|
||||||
|
|
||||||
def encode(self, prompts: list[str], *args,
|
def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
|
||||||
**kwargs) -> list[list[torch.Tensor]]:
|
|
||||||
return self.model.encode(prompts, *args, **kwargs)
|
return self.model.encode(prompts, *args, **kwargs)
|
||||||
|
|
||||||
def predict(self, prompts: list[list[str]], *args,
|
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
|
||||||
**kwargs) -> torch.Tensor:
|
return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)
|
||||||
return self.model.predict(prompts,
|
|
||||||
*args,
|
|
||||||
convert_to_tensor=True,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
@@ -727,8 +733,11 @@ class VllmRunner:
|
|||||||
default_torch_num_threads: Optional[int] = None,
|
default_torch_num_threads: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
init_ctx = (
|
||||||
set_default_torch_num_threads(default_torch_num_threads))
|
nullcontext()
|
||||||
|
if default_torch_num_threads is None
|
||||||
|
else set_default_torch_num_threads(default_torch_num_threads)
|
||||||
|
)
|
||||||
|
|
||||||
if not kwargs.get("compilation_config", None):
|
if not kwargs.get("compilation_config", None):
|
||||||
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
|
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
|
||||||
@@ -760,11 +769,12 @@ class VllmRunner:
|
|||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
if any(x is not None and len(x) != len(prompts)
|
if any(
|
||||||
for x in [images, videos, audios]):
|
x is not None and len(x) != len(prompts) for x in [images, videos, audios]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"All non-None multimodal inputs must have the same length as "
|
"All non-None multimodal inputs must have the same length as prompts"
|
||||||
"prompts")
|
)
|
||||||
|
|
||||||
inputs = list[dict[str, Any]]()
|
inputs = list[dict[str, Any]]()
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
@@ -800,14 +810,11 @@ class VllmRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
inputs = self.get_inputs(prompts,
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
images=images,
|
|
||||||
videos=videos,
|
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
req_outputs = self.llm.generate(inputs,
|
req_outputs = self.llm.generate(
|
||||||
sampling_params=sampling_params,
|
inputs, sampling_params=sampling_params, **kwargs
|
||||||
**kwargs)
|
)
|
||||||
|
|
||||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||||
for req_output in req_outputs:
|
for req_output in req_outputs:
|
||||||
@@ -834,8 +841,9 @@ class VllmRunner:
|
|||||||
output_str = sample.text
|
output_str = sample.text
|
||||||
output_ids = list(sample.token_ids)
|
output_ids = list(sample.token_ids)
|
||||||
output_logprobs = sample.logprobs
|
output_logprobs = sample.logprobs
|
||||||
outputs.append((output_ids, output_str, output_logprobs,
|
outputs.append(
|
||||||
req_output.prompt_logprobs))
|
(output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
|
||||||
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def generate_w_logprobs(
|
def generate_w_logprobs(
|
||||||
@@ -846,23 +854,22 @@ class VllmRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[list[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
list[TokensTextLogprobsPromptLogprobs]]:
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
inputs = self.get_inputs(prompts,
|
|
||||||
images=images,
|
|
||||||
videos=videos,
|
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
req_outputs = self.llm.generate(inputs,
|
req_outputs = self.llm.generate(
|
||||||
sampling_params=sampling_params,
|
inputs, sampling_params=sampling_params, **kwargs
|
||||||
**kwargs)
|
)
|
||||||
|
|
||||||
toks_str_logsprobs_prompt_logprobs = (
|
toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
|
||||||
self._final_steps_generate_w_logprobs(req_outputs))
|
req_outputs
|
||||||
|
)
|
||||||
# Omit prompt logprobs if not required by sampling params
|
# Omit prompt logprobs if not required by sampling params
|
||||||
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
return (
|
||||||
if sampling_params.prompt_logprobs is None else
|
[x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||||
toks_str_logsprobs_prompt_logprobs)
|
if sampling_params.prompt_logprobs is None
|
||||||
|
else toks_str_logsprobs_prompt_logprobs
|
||||||
|
)
|
||||||
|
|
||||||
def generate_greedy(
|
def generate_greedy(
|
||||||
self,
|
self,
|
||||||
@@ -874,14 +881,15 @@ class VllmRunner:
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[tuple[list[int], str]]:
|
) -> list[tuple[list[int], str]]:
|
||||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(
|
||||||
greedy_params,
|
prompts,
|
||||||
images=images,
|
greedy_params,
|
||||||
videos=videos,
|
images=images,
|
||||||
audios=audios,
|
videos=videos,
|
||||||
**kwargs)
|
audios=audios,
|
||||||
return [(output_ids[0], output_str[0])
|
**kwargs,
|
||||||
for output_ids, output_str in outputs]
|
)
|
||||||
|
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
|
||||||
|
|
||||||
def generate_greedy_logprobs(
|
def generate_greedy_logprobs(
|
||||||
self,
|
self,
|
||||||
@@ -895,22 +903,24 @@ class VllmRunner:
|
|||||||
stop_token_ids: Optional[list[int]] = None,
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[list[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
list[TokensTextLogprobsPromptLogprobs]]:
|
|
||||||
greedy_logprobs_params = SamplingParams(
|
greedy_logprobs_params = SamplingParams(
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
logprobs=num_logprobs,
|
logprobs=num_logprobs,
|
||||||
prompt_logprobs=num_prompt_logprobs,
|
prompt_logprobs=num_prompt_logprobs,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
stop=stop)
|
stop=stop,
|
||||||
|
)
|
||||||
|
|
||||||
return self.generate_w_logprobs(prompts,
|
return self.generate_w_logprobs(
|
||||||
greedy_logprobs_params,
|
prompts,
|
||||||
images=images,
|
greedy_logprobs_params,
|
||||||
audios=audios,
|
images=images,
|
||||||
videos=videos,
|
audios=audios,
|
||||||
**kwargs)
|
videos=videos,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
|
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
|
||||||
"""
|
"""
|
||||||
@@ -919,10 +929,9 @@ class VllmRunner:
|
|||||||
:param prompts: list of prompts to score
|
:param prompts: list of prompts to score
|
||||||
:return: perplexity score of each prompt
|
:return: perplexity score of each prompt
|
||||||
"""
|
"""
|
||||||
outputs = self.generate_greedy_logprobs(prompts,
|
outputs = self.generate_greedy_logprobs(
|
||||||
max_tokens=1,
|
prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
|
||||||
num_logprobs=None,
|
)
|
||||||
num_prompt_logprobs=0)
|
|
||||||
|
|
||||||
perplexities = []
|
perplexities = []
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
@@ -951,15 +960,13 @@ class VllmRunner:
|
|||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
concurrency_limit: Optional[int] = None,
|
concurrency_limit: Optional[int] = None,
|
||||||
) -> list[tuple[list[list[int]], list[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
inputs = self.get_inputs(prompts,
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
images=images,
|
|
||||||
videos=videos,
|
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
outputs = self.llm.beam_search(inputs,
|
outputs = self.llm.beam_search(
|
||||||
BeamSearchParams(beam_width=beam_width,
|
inputs,
|
||||||
max_tokens=max_tokens),
|
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
|
||||||
concurrency_limit=concurrency_limit)
|
concurrency_limit=concurrency_limit,
|
||||||
|
)
|
||||||
returned_outputs = []
|
returned_outputs = []
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
token_ids = [x.tokens for x in output.sequences]
|
token_ids = [x.tokens for x in output.sequences]
|
||||||
@@ -971,17 +978,16 @@ class VllmRunner:
|
|||||||
req_outputs = self.llm.classify(prompts)
|
req_outputs = self.llm.classify(prompts)
|
||||||
return [req_output.outputs.probs for req_output in req_outputs]
|
return [req_output.outputs.probs for req_output in req_outputs]
|
||||||
|
|
||||||
def embed(self,
|
def embed(
|
||||||
prompts: list[str],
|
self,
|
||||||
images: Optional[PromptImageInput] = None,
|
prompts: list[str],
|
||||||
videos: Optional[PromptVideoInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
*args,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs) -> list[list[float]]:
|
*args,
|
||||||
inputs = self.get_inputs(prompts,
|
**kwargs,
|
||||||
images=images,
|
) -> list[list[float]]:
|
||||||
videos=videos,
|
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||||
audios=audios)
|
|
||||||
|
|
||||||
req_outputs = self.llm.embed(inputs, *args, **kwargs)
|
req_outputs = self.llm.embed(inputs, *args, **kwargs)
|
||||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||||
@@ -1026,6 +1032,7 @@ def vllm_runner():
|
|||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def temporary_enable_log_propagate():
|
def temporary_enable_log_propagate():
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger("vllm")
|
logger = logging.getLogger("vllm")
|
||||||
logger.propagate = True
|
logger.propagate = True
|
||||||
yield
|
yield
|
||||||
@@ -1045,6 +1052,7 @@ def num_gpus_available():
|
|||||||
in current process."""
|
in current process."""
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
return current_platform.device_count()
|
return current_platform.device_count()
|
||||||
|
|
||||||
|
|
||||||
@@ -1058,12 +1066,11 @@ _dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
|
|||||||
def dummy_opt_path():
|
def dummy_opt_path():
|
||||||
json_path = os.path.join(_dummy_opt_path, "config.json")
|
json_path = os.path.join(_dummy_opt_path, "config.json")
|
||||||
if not os.path.exists(_dummy_opt_path):
|
if not os.path.exists(_dummy_opt_path):
|
||||||
snapshot_download(repo_id="facebook/opt-125m",
|
snapshot_download(
|
||||||
local_dir=_dummy_opt_path,
|
repo_id="facebook/opt-125m",
|
||||||
ignore_patterns=[
|
local_dir=_dummy_opt_path,
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
|
||||||
"*.msgpack"
|
)
|
||||||
])
|
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
@@ -1077,12 +1084,18 @@ def dummy_opt_path():
|
|||||||
def dummy_llava_path():
|
def dummy_llava_path():
|
||||||
json_path = os.path.join(_dummy_llava_path, "config.json")
|
json_path = os.path.join(_dummy_llava_path, "config.json")
|
||||||
if not os.path.exists(_dummy_llava_path):
|
if not os.path.exists(_dummy_llava_path):
|
||||||
snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
|
snapshot_download(
|
||||||
local_dir=_dummy_llava_path,
|
repo_id="llava-hf/llava-1.5-7b-hf",
|
||||||
ignore_patterns=[
|
local_dir=_dummy_llava_path,
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
ignore_patterns=[
|
||||||
"*.msgpack", "*.safetensors"
|
"*.bin",
|
||||||
])
|
"*.bin.index.json",
|
||||||
|
"*.pt",
|
||||||
|
"*.h5",
|
||||||
|
"*.msgpack",
|
||||||
|
"*.safetensors",
|
||||||
|
],
|
||||||
|
)
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
@@ -1096,12 +1109,18 @@ def dummy_llava_path():
|
|||||||
def dummy_gemma2_embedding_path():
|
def dummy_gemma2_embedding_path():
|
||||||
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
|
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
|
||||||
if not os.path.exists(_dummy_gemma2_embedding_path):
|
if not os.path.exists(_dummy_gemma2_embedding_path):
|
||||||
snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
|
snapshot_download(
|
||||||
local_dir=_dummy_gemma2_embedding_path,
|
repo_id="BAAI/bge-multilingual-gemma2",
|
||||||
ignore_patterns=[
|
local_dir=_dummy_gemma2_embedding_path,
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
ignore_patterns=[
|
||||||
"*.msgpack", "*.safetensors"
|
"*.bin",
|
||||||
])
|
"*.bin.index.json",
|
||||||
|
"*.pt",
|
||||||
|
"*.h5",
|
||||||
|
"*.msgpack",
|
||||||
|
"*.safetensors",
|
||||||
|
],
|
||||||
|
)
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
@@ -1114,10 +1133,9 @@ def dummy_gemma2_embedding_path():
|
|||||||
# Add the flag `--optional` to allow run tests
|
# Add the flag `--optional` to allow run tests
|
||||||
# that are marked with @pytest.mark.optional
|
# that are marked with @pytest.mark.optional
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
parser.addoption("--optional",
|
parser.addoption(
|
||||||
action="store_true",
|
"--optional", action="store_true", default=False, help="run optional test"
|
||||||
default=False,
|
)
|
||||||
help="run optional test")
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(config, items):
|
def pytest_collection_modifyitems(config, items):
|
||||||
@@ -1185,7 +1203,6 @@ def _find_free_port() -> int:
|
|||||||
|
|
||||||
|
|
||||||
class LocalAssetServer:
|
class LocalAssetServer:
|
||||||
|
|
||||||
address: str
|
address: str
|
||||||
port: int
|
port: int
|
||||||
server: Optional[http.server.ThreadingHTTPServer]
|
server: Optional[http.server.ThreadingHTTPServer]
|
||||||
@@ -1200,9 +1217,9 @@ class LocalAssetServer:
|
|||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.port = _find_free_port()
|
self.port = _find_free_port()
|
||||||
self.server = http.server.ThreadingHTTPServer(
|
self.server = http.server.ThreadingHTTPServer(
|
||||||
(self.address, self.port), AssetHandler)
|
(self.address, self.port), AssetHandler
|
||||||
self.thread = threading.Thread(target=self.server.serve_forever,
|
)
|
||||||
daemon=True)
|
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -1236,7 +1253,7 @@ class LocalAssetServer:
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def local_asset_server() -> Generator[LocalAssetServer, None, None]:
|
def local_asset_server() -> Generator[LocalAssetServer, None, None]:
|
||||||
"""
|
"""
|
||||||
Starts a thread based HTTP server bound to 127.0.0.1 on a random free port.
|
Starts a thread based HTTP server bound to 127.0.0.1 on a random free port.
|
||||||
The server currently servers images at:
|
The server currently servers images at:
|
||||||
http://127.0.0.1:<port>/<name>.<ext>
|
http://127.0.0.1:<port>/<name>.<ext>
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from vllm.platforms import current_platform
|
|||||||
def check_cuda_context():
|
def check_cuda_context():
|
||||||
"""Check CUDA driver context status"""
|
"""Check CUDA driver context status"""
|
||||||
try:
|
try:
|
||||||
cuda = ctypes.CDLL('libcuda.so')
|
cuda = ctypes.CDLL("libcuda.so")
|
||||||
device = ctypes.c_int()
|
device = ctypes.c_int()
|
||||||
result = cuda.cuCtxGetDevice(ctypes.byref(device))
|
result = cuda.cuCtxGetDevice(ctypes.byref(device))
|
||||||
return (True, device.value) if result == 0 else (False, None)
|
return (True, device.value) if result == 0 else (False, None)
|
||||||
@@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
|||||||
# New thread should have no CUDA context initially
|
# New thread should have no CUDA context initially
|
||||||
valid_before, device_before = check_cuda_context()
|
valid_before, device_before = check_cuda_context()
|
||||||
if valid_before:
|
if valid_before:
|
||||||
return False, \
|
return (
|
||||||
"CUDA context should not exist in new thread, " \
|
False,
|
||||||
f"got device {device_before}"
|
"CUDA context should not exist in new thread, "
|
||||||
|
f"got device {device_before}",
|
||||||
|
)
|
||||||
|
|
||||||
# Test setting CUDA context
|
# Test setting CUDA context
|
||||||
current_platform.set_device(device_input)
|
current_platform.set_device(device_input)
|
||||||
@@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
|||||||
if not valid_after:
|
if not valid_after:
|
||||||
return False, "CUDA context should be valid after set_cuda_context"
|
return False, "CUDA context should be valid after set_cuda_context"
|
||||||
if device_id != expected_device_id:
|
if device_id != expected_device_id:
|
||||||
return False, \
|
return False, f"Expected device {expected_device_id}, got {device_id}"
|
||||||
f"Expected device {expected_device_id}, got {device_id}"
|
|
||||||
|
|
||||||
return True, "Success"
|
return True, "Success"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
|||||||
class TestSetCudaContext:
|
class TestSetCudaContext:
|
||||||
"""Test suite for the set_cuda_context function."""
|
"""Test suite for the set_cuda_context function."""
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||||
reason="CUDA not available")
|
@pytest.mark.parametrize(
|
||||||
@pytest.mark.parametrize(argnames="device_input,expected_device_id",
|
argnames="device_input,expected_device_id",
|
||||||
argvalues=[
|
argvalues=[
|
||||||
(0, 0),
|
(0, 0),
|
||||||
(torch.device('cuda:0'), 0),
|
(torch.device("cuda:0"), 0),
|
||||||
('cuda:0', 0),
|
("cuda:0", 0),
|
||||||
],
|
],
|
||||||
ids=["int", "torch_device", "string"])
|
ids=["int", "torch_device", "string"],
|
||||||
def test_set_cuda_context_parametrized(self, device_input,
|
)
|
||||||
expected_device_id):
|
def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
|
||||||
"""Test setting CUDA context in isolated threads."""
|
"""Test setting CUDA context in isolated threads."""
|
||||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
future = executor.submit(run_cuda_test_in_thread, device_input,
|
future = executor.submit(
|
||||||
expected_device_id)
|
run_cuda_test_in_thread, device_input, expected_device_id
|
||||||
|
)
|
||||||
success, message = future.result(timeout=30)
|
success, message = future.result(timeout=30)
|
||||||
assert success, message
|
assert success, message
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||||
reason="CUDA not available")
|
|
||||||
def test_set_cuda_context_invalid_device_type(self):
|
def test_set_cuda_context_invalid_device_type(self):
|
||||||
"""Test error handling for invalid device type."""
|
"""Test error handling for invalid device type."""
|
||||||
with pytest.raises(ValueError, match="Expected a cuda device"):
|
with pytest.raises(ValueError, match="Expected a cuda device"):
|
||||||
current_platform.set_device(torch.device('cpu'))
|
current_platform.set_device(torch.device("cpu"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -17,20 +17,16 @@ def test_computed_prefix_blocks(model: str):
|
|||||||
prompt = (
|
prompt = (
|
||||||
"You are a helpful assistant. How do I build a car from cardboard and "
|
"You are a helpful assistant. How do I build a car from cardboard and "
|
||||||
"paper clips? Is there an easy to follow video tutorial available "
|
"paper clips? Is there an easy to follow video tutorial available "
|
||||||
"online for free?")
|
"online for free?"
|
||||||
|
)
|
||||||
|
|
||||||
llm = LLM(model=model)
|
llm = LLM(model=model)
|
||||||
sampling_params = SamplingParams(max_tokens=10,
|
sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
|
||||||
temperature=0.0,
|
|
||||||
detokenize=False)
|
|
||||||
|
|
||||||
outputs_no_detokenization = llm.generate(prompt,
|
outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||||
sampling_params)[0].outputs[0]
|
|
||||||
sampling_params.detokenize = True
|
sampling_params.detokenize = True
|
||||||
outputs_with_detokenization = llm.generate(prompt,
|
outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||||
sampling_params)[0].outputs[0]
|
|
||||||
|
|
||||||
assert outputs_no_detokenization.text == ''
|
assert outputs_no_detokenization.text == ""
|
||||||
assert outputs_with_detokenization.text != ''
|
assert outputs_with_detokenization.text != ""
|
||||||
assert outputs_no_detokenization.token_ids == \
|
assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids
|
||||||
outputs_with_detokenization.token_ids
|
|
||||||
|
|||||||
@@ -8,15 +8,17 @@ from vllm import SamplingParams
|
|||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
|
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
|
||||||
|
|
||||||
PROMPT = "Hello, my name is Lee, and I'm a student in the " + \
|
PROMPT = "Hello, my name is Lee, and I'm a student in the " + "college of engineering"
|
||||||
"college of engineering"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_tokens,stop,truth", [
|
@pytest.mark.parametrize(
|
||||||
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
"min_tokens,stop,truth",
|
||||||
(0, "e", " is L"),
|
[
|
||||||
(5, "e", " is Lee, and I'm a stud"),
|
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
||||||
])
|
(0, "e", " is L"),
|
||||||
|
(5, "e", " is Lee, and I'm a stud"),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
||||||
"""Test for a specific min_tokens and stop.
|
"""Test for a specific min_tokens and stop.
|
||||||
|
|
||||||
@@ -31,16 +33,18 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
|||||||
stop=stop,
|
stop=stop,
|
||||||
min_tokens=min_tokens,
|
min_tokens=min_tokens,
|
||||||
)
|
)
|
||||||
request = EngineCoreRequest(request_id="",
|
request = EngineCoreRequest(
|
||||||
prompt_token_ids=prompt_token_ids,
|
request_id="",
|
||||||
mm_features=None,
|
prompt_token_ids=prompt_token_ids,
|
||||||
sampling_params=params,
|
mm_features=None,
|
||||||
pooling_params=None,
|
sampling_params=params,
|
||||||
eos_token_id=None,
|
pooling_params=None,
|
||||||
arrival_time=0.0,
|
eos_token_id=None,
|
||||||
lora_request=None,
|
arrival_time=0.0,
|
||||||
cache_salt=None,
|
lora_request=None,
|
||||||
data_parallel_rank=None)
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
|
)
|
||||||
|
|
||||||
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
|
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
|
||||||
|
|
||||||
|
|||||||
@@ -31,34 +31,39 @@ def test_stop_reason(vllm_model, example_prompts):
|
|||||||
llm = vllm_model.llm
|
llm = vllm_model.llm
|
||||||
|
|
||||||
# test stop token
|
# test stop token
|
||||||
outputs = llm.generate(example_prompts,
|
outputs = llm.generate(
|
||||||
sampling_params=SamplingParams(
|
example_prompts,
|
||||||
ignore_eos=True,
|
sampling_params=SamplingParams(
|
||||||
seed=SEED,
|
ignore_eos=True,
|
||||||
max_tokens=MAX_TOKENS,
|
seed=SEED,
|
||||||
stop_token_ids=[stop_token_id]))
|
max_tokens=MAX_TOKENS,
|
||||||
|
stop_token_ids=[stop_token_id],
|
||||||
|
),
|
||||||
|
)
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
output = output.outputs[0]
|
output = output.outputs[0]
|
||||||
assert output.finish_reason == "stop"
|
assert output.finish_reason == "stop"
|
||||||
assert output.stop_reason == stop_token_id
|
assert output.stop_reason == stop_token_id
|
||||||
|
|
||||||
# test stop string
|
# test stop string
|
||||||
outputs = llm.generate(example_prompts,
|
outputs = llm.generate(
|
||||||
sampling_params=SamplingParams(
|
example_prompts,
|
||||||
ignore_eos=True,
|
sampling_params=SamplingParams(
|
||||||
seed=SEED,
|
ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="."
|
||||||
max_tokens=MAX_TOKENS,
|
),
|
||||||
stop="."))
|
)
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
output = output.outputs[0]
|
output = output.outputs[0]
|
||||||
assert output.finish_reason == "stop"
|
assert output.finish_reason == "stop"
|
||||||
assert output.stop_reason == STOP_STR
|
assert output.stop_reason == STOP_STR
|
||||||
|
|
||||||
# test EOS token
|
# test EOS token
|
||||||
outputs = llm.generate(example_prompts,
|
outputs = llm.generate(
|
||||||
sampling_params=SamplingParams(
|
example_prompts,
|
||||||
seed=SEED, max_tokens=MAX_TOKENS))
|
sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS),
|
||||||
|
)
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
output = output.outputs[0]
|
output = output.outputs[0]
|
||||||
assert output.finish_reason == "length" or (
|
assert output.finish_reason == "length" or (
|
||||||
output.finish_reason == "stop" and output.stop_reason is None)
|
output.finish_reason == "stop" and output.stop_reason is None
|
||||||
|
)
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ def include_stop_str_in_output(request):
|
|||||||
|
|
||||||
|
|
||||||
class _DummyDetokenizer(BaseIncrementalDetokenizer):
|
class _DummyDetokenizer(BaseIncrementalDetokenizer):
|
||||||
|
|
||||||
def __init__(self, request: EngineCoreRequest):
|
def __init__(self, request: EngineCoreRequest):
|
||||||
super().__init__(request)
|
super().__init__(request)
|
||||||
|
|
||||||
@@ -27,7 +26,8 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
|||||||
params = SamplingParams(
|
params = SamplingParams(
|
||||||
stop=stop,
|
stop=stop,
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
min_tokens=min_tokens)
|
min_tokens=min_tokens,
|
||||||
|
)
|
||||||
# Keep other fields minimal for unit test purposes.
|
# Keep other fields minimal for unit test purposes.
|
||||||
req = EngineCoreRequest(
|
req = EngineCoreRequest(
|
||||||
request_id="test",
|
request_id="test",
|
||||||
@@ -44,26 +44,25 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
|||||||
return req
|
return req
|
||||||
|
|
||||||
|
|
||||||
def test_stop_string_while_stop_token_terminates(
|
def test_stop_string_while_stop_token_terminates(include_stop_str_in_output: bool):
|
||||||
include_stop_str_in_output: bool):
|
|
||||||
"""
|
"""
|
||||||
This test verifies that the detokenizer correctly handles the case where
|
This test verifies that the detokenizer correctly handles the case where
|
||||||
the generated token sequence contains both:
|
the generated token sequence contains both:
|
||||||
- a stop token
|
- a stop token
|
||||||
- an <eos> token
|
- an <eos> token
|
||||||
|
|
||||||
The detokenizer should respect the stop string and truncate the output
|
The detokenizer should respect the stop string and truncate the output
|
||||||
accordingly.
|
accordingly.
|
||||||
|
|
||||||
Imagine the following sequence:
|
Imagine the following sequence:
|
||||||
- "abcdeZ" is generated, where "Z" is the <eos> token.
|
- "abcdeZ" is generated, where "Z" is the <eos> token.
|
||||||
- "cd" is the stop string.
|
- "cd" is the stop string.
|
||||||
|
|
||||||
If include_stop_str_in_output=False, the detokenizer should truncate the
|
If include_stop_str_in_output=False, the detokenizer should truncate the
|
||||||
output to "ab" because the stop string "cd" is excluded.
|
output to "ab" because the stop string "cd" is excluded.
|
||||||
If include_stop_str_in_output=True, the detokenizer should include the stop
|
If include_stop_str_in_output=True, the detokenizer should include the stop
|
||||||
string "cd" in the output, resulting in "abcd".
|
string "cd" in the output, resulting in "abcd".
|
||||||
|
|
||||||
|
|
||||||
This verifies the behavioral change introduced in BaseIncrementalDetokenizer
|
This verifies the behavioral change introduced in BaseIncrementalDetokenizer
|
||||||
where stop-string evaluation occurs before the early-return on
|
where stop-string evaluation occurs before the early-return on
|
||||||
@@ -78,8 +77,9 @@ def test_stop_string_while_stop_token_terminates(
|
|||||||
token_ids = [ord(c) for c in generated_text]
|
token_ids = [ord(c) for c in generated_text]
|
||||||
|
|
||||||
# Create a request with the stop string and initialize the detokenizer.
|
# Create a request with the stop string and initialize the detokenizer.
|
||||||
req = _make_request(stop=[stop_string],
|
req = _make_request(
|
||||||
include_stop_str_in_output=include_stop_str_in_output)
|
stop=[stop_string], include_stop_str_in_output=include_stop_str_in_output
|
||||||
|
)
|
||||||
detok = _DummyDetokenizer(req)
|
detok = _DummyDetokenizer(req)
|
||||||
|
|
||||||
# Simulate that the last token ('Z') is a stop token (stop_terminated=True).
|
# Simulate that the last token ('Z') is a stop token (stop_terminated=True).
|
||||||
@@ -99,5 +99,4 @@ def test_stop_string_while_stop_token_terminates(
|
|||||||
|
|
||||||
# get_next_output_text should return the full text when finished=True.
|
# get_next_output_text should return the full text when finished=True.
|
||||||
# (Buffering only applies during streaming when finished=False.)
|
# (Buffering only applies during streaming when finished=False.)
|
||||||
assert detok.get_next_output_text(finished=True,
|
assert detok.get_next_output_text(finished=True, delta=False) == expected_text
|
||||||
delta=False) == expected_text
|
|
||||||
|
|||||||
@@ -11,12 +11,14 @@ MODEL = "meta-llama/llama-2-7b-hf"
|
|||||||
MAX_TOKENS = 200
|
MAX_TOKENS = 200
|
||||||
|
|
||||||
|
|
||||||
def _test_stopping(llm: LLM,
|
def _test_stopping(
|
||||||
expected_output: str,
|
llm: LLM,
|
||||||
expected_reason: Any,
|
expected_output: str,
|
||||||
stop: Optional[list[str]] = None,
|
expected_reason: Any,
|
||||||
stop_token_ids: Optional[list[int]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
include_in_output: bool = False) -> None:
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
|
include_in_output: bool = False,
|
||||||
|
) -> None:
|
||||||
output = llm.generate(
|
output = llm.generate(
|
||||||
"A story about vLLM:\n",
|
"A story about vLLM:\n",
|
||||||
SamplingParams(
|
SamplingParams(
|
||||||
@@ -25,7 +27,8 @@ def _test_stopping(llm: LLM,
|
|||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
include_stop_str_in_output=include_in_output,
|
include_stop_str_in_output=include_in_output,
|
||||||
))[0].outputs[0]
|
),
|
||||||
|
)[0].outputs[0]
|
||||||
|
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert output.text == expected_output
|
assert output.text == expected_output
|
||||||
@@ -33,17 +36,21 @@ def _test_stopping(llm: LLM,
|
|||||||
|
|
||||||
|
|
||||||
def _stop_basic(llm):
|
def _stop_basic(llm):
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
stop=["."],
|
llm,
|
||||||
include_in_output=False,
|
stop=["."],
|
||||||
expected_output="VLLM is a 100% volunteer organization",
|
include_in_output=False,
|
||||||
expected_reason=".")
|
expected_output="VLLM is a 100% volunteer organization",
|
||||||
|
expected_reason=".",
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
stop=["."],
|
llm,
|
||||||
include_in_output=True,
|
stop=["."],
|
||||||
expected_output="VLLM is a 100% volunteer organization.",
|
include_in_output=True,
|
||||||
expected_reason=".")
|
expected_output="VLLM is a 100% volunteer organization.",
|
||||||
|
expected_reason=".",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _stop_multi_tokens(llm):
|
def _stop_multi_tokens(llm):
|
||||||
@@ -52,45 +59,54 @@ def _stop_multi_tokens(llm):
|
|||||||
stop=["group of peo", "short"],
|
stop=["group of peo", "short"],
|
||||||
include_in_output=False,
|
include_in_output=False,
|
||||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||||
expected_reason="group of peo")
|
expected_reason="group of peo",
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(
|
_test_stopping(
|
||||||
llm,
|
llm,
|
||||||
stop=["group of peo", "short"],
|
stop=["group of peo", "short"],
|
||||||
include_in_output=True,
|
include_in_output=True,
|
||||||
expected_output=
|
expected_output="VLLM is a 100% volunteer organization. We are a group of peo",
|
||||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
expected_reason="group of peo",
|
||||||
expected_reason="group of peo")
|
)
|
||||||
|
|
||||||
|
|
||||||
def _stop_partial_token(llm):
|
def _stop_partial_token(llm):
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
stop=["gani"],
|
llm,
|
||||||
include_in_output=False,
|
stop=["gani"],
|
||||||
expected_output="VLLM is a 100% volunteer or",
|
include_in_output=False,
|
||||||
expected_reason="gani")
|
expected_output="VLLM is a 100% volunteer or",
|
||||||
|
expected_reason="gani",
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
stop=["gani"],
|
llm,
|
||||||
include_in_output=True,
|
stop=["gani"],
|
||||||
expected_output="VLLM is a 100% volunteer organi",
|
include_in_output=True,
|
||||||
expected_reason="gani")
|
expected_output="VLLM is a 100% volunteer organi",
|
||||||
|
expected_reason="gani",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _stop_token_id(llm):
|
def _stop_token_id(llm):
|
||||||
# token id 13013 => " organization"
|
# token id 13013 => " organization"
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
stop_token_ids=[13013],
|
llm,
|
||||||
include_in_output=False,
|
stop_token_ids=[13013],
|
||||||
expected_output="VLLM is a 100% volunteer",
|
include_in_output=False,
|
||||||
expected_reason=13013)
|
expected_output="VLLM is a 100% volunteer",
|
||||||
|
expected_reason=13013,
|
||||||
|
)
|
||||||
|
|
||||||
_test_stopping(llm,
|
_test_stopping(
|
||||||
stop_token_ids=[13013],
|
llm,
|
||||||
include_in_output=True,
|
stop_token_ids=[13013],
|
||||||
expected_output="VLLM is a 100% volunteer organization",
|
include_in_output=True,
|
||||||
expected_reason=13013)
|
expected_output="VLLM is a 100% volunteer organization",
|
||||||
|
expected_reason=13013,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
|
|||||||
@@ -111,8 +111,7 @@ class MockSubscriber:
|
|||||||
self.last_seq = -1
|
self.last_seq = -1
|
||||||
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
||||||
|
|
||||||
def receive_one(self,
|
def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]:
|
||||||
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
|
|
||||||
"""Receive a single message with timeout"""
|
"""Receive a single message with timeout"""
|
||||||
if not self.sub.poll(timeout):
|
if not self.sub.poll(timeout):
|
||||||
return None
|
return None
|
||||||
@@ -135,8 +134,7 @@ class MockSubscriber:
|
|||||||
|
|
||||||
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
||||||
|
|
||||||
def receive_replay(self,
|
def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
||||||
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
|
||||||
"""Receive replayed messages from a specific replay socket"""
|
"""Receive replayed messages from a specific replay socket"""
|
||||||
if not self.replay_sockets:
|
if not self.replay_sockets:
|
||||||
raise ValueError("Replay sockets not initialized")
|
raise ValueError("Replay sockets not initialized")
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||||
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
|
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
|
||||||
CustomAllreduce)
|
CustomAllreduce,
|
||||||
|
)
|
||||||
|
|
||||||
# create a cpu process group for communicating metadata (ipc handle)
|
# create a cpu process group for communicating metadata (ipc handle)
|
||||||
dist.init_process_group(backend="gloo")
|
dist.init_process_group(backend="gloo")
|
||||||
@@ -52,7 +53,8 @@ for p in pointers:
|
|||||||
assert ord(host_data[i]) == byte_value, (
|
assert ord(host_data[i]) == byte_value, (
|
||||||
f"Rank {rank} failed"
|
f"Rank {rank} failed"
|
||||||
f" to verify buffer {p}. Expected {byte_value}, "
|
f" to verify buffer {p}. Expected {byte_value}, "
|
||||||
f"got {ord(host_data[i])}")
|
f"got {ord(host_data[i])}"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Rank {rank} verified all buffers")
|
print(f"Rank {rank} verified all buffers")
|
||||||
|
|
||||||
|
|||||||
@@ -13,13 +13,19 @@ import pytest
|
|||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
|
from vllm.distributed import (
|
||||||
tensor_model_parallel_all_gather,
|
broadcast_tensor_dict,
|
||||||
tensor_model_parallel_all_reduce,
|
get_pp_group,
|
||||||
tensor_model_parallel_reduce_scatter)
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
tensor_model_parallel_reduce_scatter,
|
||||||
|
)
|
||||||
|
|
||||||
from ..utils import (init_test_distributed_environment, multi_gpu_test,
|
from ..utils import (
|
||||||
multi_process_parallel)
|
init_test_distributed_environment,
|
||||||
|
multi_gpu_test,
|
||||||
|
multi_process_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@ray.remote(num_gpus=1, max_calls=1)
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
@@ -37,12 +43,11 @@ def all_reduce_test_worker(
|
|||||||
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
num_elements = 8
|
num_elements = 8
|
||||||
all_tensors = [
|
all_tensors = [
|
||||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
|
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
|
||||||
(r + 1) for r in range(tp_size)
|
for r in range(tp_size)
|
||||||
]
|
]
|
||||||
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||||
t = all_tensors[rank % tp_size]
|
t = all_tensors[rank % tp_size]
|
||||||
@@ -51,28 +56,31 @@ def all_reduce_test_worker(
|
|||||||
|
|
||||||
|
|
||||||
@ray.remote(num_gpus=1, max_calls=1)
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
|
def reduce_scatter_test_worker(
|
||||||
pp_size: int, rank: int,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
distributed_init_port: str):
|
tp_size: int,
|
||||||
|
pp_size: int,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_port: str,
|
||||||
|
):
|
||||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||||
# so that each worker can see all the GPUs
|
# so that each worker can see all the GPUs
|
||||||
# they will be able to set the device to the correct GPU
|
# they will be able to set the device to the correct GPU
|
||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
num_elements = 8
|
num_elements = 8
|
||||||
all_tensors = [
|
all_tensors = [
|
||||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
|
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
|
||||||
(r + 1) for r in range(tp_size)
|
for r in range(tp_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
index = rank % tp_size
|
index = rank % tp_size
|
||||||
partition_size = num_elements // tp_size
|
partition_size = num_elements // tp_size
|
||||||
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||||
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
|
expected = all_reduce[index * partition_size : (index + 1) * partition_size]
|
||||||
t = all_tensors[index]
|
t = all_tensors[index]
|
||||||
t = tensor_model_parallel_reduce_scatter(t, 0)
|
t = tensor_model_parallel_reduce_scatter(t, 0)
|
||||||
torch.testing.assert_close(t, expected)
|
torch.testing.assert_close(t, expected)
|
||||||
@@ -92,8 +100,7 @@ def all_gather_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
num_dimensions = 3
|
num_dimensions = 3
|
||||||
tensor_size = list(range(2, num_dimensions + 2))
|
tensor_size = list(range(2, num_dimensions + 2))
|
||||||
total_size = 1
|
total_size = 1
|
||||||
@@ -101,8 +108,10 @@ def all_gather_test_worker(
|
|||||||
total_size *= s
|
total_size *= s
|
||||||
for all_gather_dimension in range(num_dimensions):
|
for all_gather_dimension in range(num_dimensions):
|
||||||
all_tensors = [
|
all_tensors = [
|
||||||
torch.arange(total_size, dtype=torch.float32,
|
torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(
|
||||||
device="cuda").reshape(tensor_size) * (r + 1)
|
tensor_size
|
||||||
|
)
|
||||||
|
* (r + 1)
|
||||||
for r in range(tp_size)
|
for r in range(tp_size)
|
||||||
]
|
]
|
||||||
expected = torch.cat(all_tensors, dim=all_gather_dimension)
|
expected = torch.cat(all_tensors, dim=all_gather_dimension)
|
||||||
@@ -125,8 +134,7 @@ def broadcast_tensor_dict_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
test_dict = {
|
test_dict = {
|
||||||
# device tensor
|
# device tensor
|
||||||
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||||
@@ -134,10 +142,7 @@ def broadcast_tensor_dict_test_worker(
|
|||||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||||
"c": "test",
|
"c": "test",
|
||||||
"d": [1, 2, 3],
|
"d": [1, 2, 3],
|
||||||
"e": {
|
"e": {"a": 1, "b": 2},
|
||||||
"a": 1,
|
|
||||||
"b": 2
|
|
||||||
},
|
|
||||||
# empty tensor
|
# empty tensor
|
||||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||||
}
|
}
|
||||||
@@ -166,8 +171,7 @@ def send_recv_tensor_dict_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
test_dict = {
|
test_dict = {
|
||||||
# device tensor
|
# device tensor
|
||||||
@@ -176,10 +180,7 @@ def send_recv_tensor_dict_test_worker(
|
|||||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||||
"c": "test",
|
"c": "test",
|
||||||
"d": [1, 2, 3],
|
"d": [1, 2, 3],
|
||||||
"e": {
|
"e": {"a": 1, "b": 2},
|
||||||
"a": 1,
|
|
||||||
"b": 2
|
|
||||||
},
|
|
||||||
# empty tensor
|
# empty tensor
|
||||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||||
}
|
}
|
||||||
@@ -211,8 +212,7 @@ def send_recv_test_worker(
|
|||||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
size = 64
|
size = 64
|
||||||
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
|
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
|
||||||
@@ -229,10 +229,10 @@ def send_recv_test_worker(
|
|||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("test_target", [
|
@pytest.mark.parametrize(
|
||||||
all_reduce_test_worker, all_gather_test_worker,
|
"test_target",
|
||||||
broadcast_tensor_dict_test_worker
|
[all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker],
|
||||||
])
|
)
|
||||||
def test_multi_process_tensor_parallel(
|
def test_multi_process_tensor_parallel(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
@@ -244,7 +244,8 @@ def test_multi_process_tensor_parallel(
|
|||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize("pp_size", [2])
|
@pytest.mark.parametrize("pp_size", [2])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
|
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]
|
||||||
|
)
|
||||||
def test_multi_process_pipeline_parallel(
|
def test_multi_process_pipeline_parallel(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
pp_size: int,
|
pp_size: int,
|
||||||
@@ -256,11 +257,16 @@ def test_multi_process_pipeline_parallel(
|
|||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("pp_size", [2])
|
@pytest.mark.parametrize("pp_size", [2])
|
||||||
@pytest.mark.parametrize("test_target", [
|
@pytest.mark.parametrize(
|
||||||
send_recv_test_worker, send_recv_tensor_dict_test_worker,
|
"test_target",
|
||||||
all_reduce_test_worker, all_gather_test_worker,
|
[
|
||||||
broadcast_tensor_dict_test_worker
|
send_recv_test_worker,
|
||||||
])
|
send_recv_tensor_dict_test_worker,
|
||||||
|
all_reduce_test_worker,
|
||||||
|
all_gather_test_worker,
|
||||||
|
broadcast_tensor_dict_test_worker,
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_multi_process_tensor_parallel_pipeline_parallel(
|
def test_multi_process_tensor_parallel_pipeline_parallel(
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
pp_size: int,
|
pp_size: int,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
|||||||
all workers in a node other than the head node, which can cause the test
|
all workers in a node other than the head node, which can cause the test
|
||||||
to fail.
|
to fail.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -56,7 +57,8 @@ class CPTestSettings:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Length mismatch: distributed_backends "
|
f"Length mismatch: distributed_backends "
|
||||||
f"({len(self.distributed_backends)}) != "
|
f"({len(self.distributed_backends)}) != "
|
||||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
f"vllm_major_versions ({len(self.vllm_major_versions)})"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def detailed(
|
def detailed(
|
||||||
@@ -74,29 +76,39 @@ class CPTestSettings:
|
|||||||
for dcp_multiplier in [0.5, 1]:
|
for dcp_multiplier in [0.5, 1]:
|
||||||
for chunked_prefill_val in [True]:
|
for chunked_prefill_val in [True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
pp_size=pp_multiplier * pp_base,
|
tp_size=tp_base,
|
||||||
dcp_size=int(dcp_multiplier *
|
pp_size=pp_multiplier * pp_base,
|
||||||
tp_base),
|
dcp_size=int(dcp_multiplier * tp_base),
|
||||||
eager_mode=eager_mode_val,
|
eager_mode=eager_mode_val,
|
||||||
chunked_prefill=chunked_prefill_val))
|
chunked_prefill=chunked_prefill_val,
|
||||||
|
)
|
||||||
|
)
|
||||||
return CPTestSettings(
|
return CPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp"],
|
distributed_backends=["mp"],
|
||||||
vllm_major_versions=["1"],
|
vllm_major_versions=["1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=CPTestOptions(multi_node_only=multi_node_only,
|
test_options=CPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_id: str):
|
def iter_params(self, model_id: str):
|
||||||
opts = self.test_options
|
opts = self.test_options
|
||||||
|
|
||||||
for parallel_setup in self.parallel_setups:
|
for parallel_setup in self.parallel_setups:
|
||||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
for backend, vllm_major_version in zip(
|
||||||
self.vllm_major_versions):
|
self.distributed_backends, self.vllm_major_versions
|
||||||
yield (model_id, parallel_setup, backend, vllm_major_version,
|
):
|
||||||
self.runner, opts)
|
yield (
|
||||||
|
model_id,
|
||||||
|
parallel_setup,
|
||||||
|
backend,
|
||||||
|
vllm_major_version,
|
||||||
|
self.runner,
|
||||||
|
opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _compare_cp_with_tp(
|
def _compare_cp_with_tp(
|
||||||
@@ -148,8 +160,10 @@ def _compare_cp_with_tp(
|
|||||||
if num_gpus_available < tp_size * pp_size:
|
if num_gpus_available < tp_size * pp_size:
|
||||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
pytest.skip(
|
||||||
"multiprocessing distributed backend")
|
"Skipping multi-node pipeline parallel test for "
|
||||||
|
"multiprocessing distributed backend"
|
||||||
|
)
|
||||||
if multi_node_only and not VLLM_MULTI_NODE:
|
if multi_node_only and not VLLM_MULTI_NODE:
|
||||||
pytest.skip("Not in multi-node setting")
|
pytest.skip("Not in multi-node setting")
|
||||||
|
|
||||||
@@ -178,8 +192,7 @@ def _compare_cp_with_tp(
|
|||||||
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||||
|
|
||||||
cp_env = tp_env = {
|
cp_env = tp_env = {
|
||||||
"VLLM_USE_V1":
|
"VLLM_USE_V1": vllm_major_version, # Note(hc): DCP only support V1 engine only
|
||||||
vllm_major_version, # Note(hc): DCP only support V1 engine only
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cp_args = [
|
cp_args = [
|
||||||
@@ -205,13 +218,15 @@ def _compare_cp_with_tp(
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(
|
||||||
cp_args,
|
model_id,
|
||||||
tp_args,
|
cp_args,
|
||||||
cp_env,
|
tp_args,
|
||||||
tp_env,
|
cp_env,
|
||||||
method=method,
|
tp_env,
|
||||||
max_wait_seconds=720)
|
method=method,
|
||||||
|
max_wait_seconds=720,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
testing_ray_compiled_graph = cp_env is not None
|
testing_ray_compiled_graph = cp_env is not None
|
||||||
if testing_ray_compiled_graph and vllm_major_version == "0":
|
if testing_ray_compiled_graph and vllm_major_version == "0":
|
||||||
@@ -224,9 +239,10 @@ def _compare_cp_with_tp(
|
|||||||
|
|
||||||
CP_TEXT_GENERATION_MODELS = {
|
CP_TEXT_GENERATION_MODELS = {
|
||||||
# [MLA attention only]
|
# [MLA attention only]
|
||||||
"deepseek-ai/DeepSeek-V2-Lite-Chat":
|
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||||
[CPTestSettings.detailed(),
|
CPTestSettings.detailed(),
|
||||||
CPTestSettings.detailed(tp_base=2)],
|
CPTestSettings.detailed(tp_base=2),
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
CP_TEST_MODELS = [
|
CP_TEST_MODELS = [
|
||||||
@@ -237,11 +253,19 @@ CP_TEST_MODELS = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
|
(
|
||||||
"runner", "test_options"),
|
"model_id",
|
||||||
|
"parallel_setup",
|
||||||
|
"distributed_backend",
|
||||||
|
"vllm_major_version",
|
||||||
|
"runner",
|
||||||
|
"test_options",
|
||||||
|
),
|
||||||
[
|
[
|
||||||
params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
|
params
|
||||||
for setting in settings for params in setting.iter_params(model_id)
|
for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
|
||||||
|
for setting in settings
|
||||||
|
for params in setting.iter_params(model_id)
|
||||||
if model_id in CP_TEST_MODELS
|
if model_id in CP_TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -255,12 +279,14 @@ def test_cp_generation(
|
|||||||
test_options: CPTestOptions,
|
test_options: CPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_cp_with_tp(model_id,
|
_compare_cp_with_tp(
|
||||||
parallel_setup,
|
model_id,
|
||||||
distributed_backend,
|
parallel_setup,
|
||||||
vllm_major_version,
|
distributed_backend,
|
||||||
runner,
|
vllm_major_version,
|
||||||
test_options,
|
runner,
|
||||||
num_gpus_available,
|
test_options,
|
||||||
method="generate",
|
num_gpus_available,
|
||||||
is_multimodal=False)
|
method="generate",
|
||||||
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|||||||
@@ -8,12 +8,14 @@ import ray
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||||
tensor_model_parallel_all_reduce)
|
|
||||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||||
|
|
||||||
from ..utils import (ensure_model_parallel_initialized,
|
from ..utils import (
|
||||||
init_test_distributed_environment, multi_process_parallel)
|
ensure_model_parallel_initialized,
|
||||||
|
init_test_distributed_environment,
|
||||||
|
multi_process_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
||||||
@@ -33,8 +35,7 @@ def graph_allreduce(
|
|||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||||
group = get_tp_group().device_group
|
group = get_tp_group().device_group
|
||||||
|
|
||||||
@@ -60,18 +61,15 @@ def graph_allreduce(
|
|||||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
with graph_capture(device=device) as graph_capture_context:
|
with graph_capture(device=device) as graph_capture_context:
|
||||||
# use integers so result matches NCCL exactly
|
# use integers so result matches NCCL exactly
|
||||||
inp1 = torch.randint(1,
|
inp1 = torch.randint(
|
||||||
16, (sz, ),
|
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
dtype=dtype,
|
)
|
||||||
device=torch.cuda.current_device())
|
inp2 = torch.randint(
|
||||||
inp2 = torch.randint(1,
|
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
16, (sz, ),
|
)
|
||||||
dtype=dtype,
|
|
||||||
device=torch.cuda.current_device())
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(graph,
|
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||||
stream=graph_capture_context.stream):
|
|
||||||
for i in range(num_communication):
|
for i in range(num_communication):
|
||||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||||
# the input buffer is immediately modified to test
|
# the input buffer is immediately modified to test
|
||||||
@@ -96,8 +94,7 @@ def eager_allreduce(
|
|||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
# we use the first group to communicate once
|
# we use the first group to communicate once
|
||||||
# and the second group to communicate twice
|
# and the second group to communicate twice
|
||||||
@@ -132,5 +129,4 @@ def test_custom_allreduce(
|
|||||||
world_size = tp_size * pipeline_parallel_size
|
world_size = tp_size * pipeline_parallel_size
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
|
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||||
test_target)
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from ..entrypoints.openai.test_oot_registration import (
|
from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server
|
||||||
run_and_test_dummy_opt_api_server)
|
|
||||||
|
|
||||||
|
|
||||||
def test_distributed_oot(dummy_opt_path: str):
|
def test_distributed_oot(dummy_opt_path: str):
|
||||||
|
|||||||
@@ -10,10 +10,12 @@ from vllm.distributed.eplb.rebalance_algo import rebalance_experts
|
|||||||
def test_basic_rebalance():
|
def test_basic_rebalance():
|
||||||
"""Test basic rebalancing functionality"""
|
"""Test basic rebalancing functionality"""
|
||||||
# Example from https://github.com/deepseek-ai/eplb
|
# Example from https://github.com/deepseek-ai/eplb
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
[
|
||||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||||
])
|
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
num_layers = weight.shape[0]
|
num_layers = weight.shape[0]
|
||||||
num_replicas = 16
|
num_replicas = 16
|
||||||
@@ -21,45 +23,49 @@ def test_basic_rebalance():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify output shapes
|
# Verify output shapes
|
||||||
assert phy2log.shape == (
|
assert phy2log.shape == (
|
||||||
2,
|
2,
|
||||||
16,
|
16,
|
||||||
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
|
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
|
||||||
assert (log2phy.shape[0] == 2
|
assert log2phy.shape[0] == 2, (
|
||||||
), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
|
f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
|
||||||
assert (
|
)
|
||||||
log2phy.shape[1] == 12
|
assert log2phy.shape[1] == 12, (
|
||||||
), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
|
f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
|
||||||
|
)
|
||||||
assert logcnt.shape == (
|
assert logcnt.shape == (
|
||||||
2,
|
2,
|
||||||
12,
|
12,
|
||||||
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
|
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
|
||||||
|
|
||||||
# Verify physical to logical expert mapping range is correct
|
# Verify physical to logical expert mapping range is correct
|
||||||
assert torch.all(phy2log >= 0) and torch.all(
|
assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), (
|
||||||
phy2log < 12), "Physical to logical mapping should be in range [0, 12)"
|
"Physical to logical mapping should be in range [0, 12)"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify expert count reasonableness
|
# Verify expert count reasonableness
|
||||||
assert torch.all(
|
assert torch.all(logcnt >= 1), "Each logical expert should have at least 1 replica"
|
||||||
logcnt >= 1), "Each logical expert should have at least 1 replica"
|
assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, (
|
||||||
assert (
|
f"Total replicas should be {num_replicas * num_layers}"
|
||||||
torch.sum(logcnt, dim=1).sum() == num_replicas *
|
)
|
||||||
num_layers), f"Total replicas should be {num_replicas * num_layers}"
|
|
||||||
|
|
||||||
# Verify expected output
|
# Verify expected output
|
||||||
expected_phy2log = torch.tensor([
|
expected_phy2log = torch.tensor(
|
||||||
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
[
|
||||||
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
|
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
||||||
])
|
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
assert torch.all(phy2log == expected_phy2log)
|
assert torch.all(phy2log == expected_phy2log)
|
||||||
|
|
||||||
expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
|
expected_logcnt = torch.tensor(
|
||||||
[1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]])
|
[[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]
|
||||||
|
)
|
||||||
assert torch.all(logcnt == expected_logcnt)
|
assert torch.all(logcnt == expected_logcnt)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,9 +77,9 @@ def test_single_gpu_case():
|
|||||||
num_nodes = 1
|
num_nodes = 1
|
||||||
num_gpus = 1
|
num_gpus = 1
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 4)
|
assert phy2log.shape == (1, 4)
|
||||||
@@ -93,19 +99,19 @@ def test_equal_weights():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 8)
|
assert phy2log.shape == (1, 8)
|
||||||
assert logcnt.shape == (1, 8)
|
assert logcnt.shape == (1, 8)
|
||||||
|
|
||||||
# With equal weights, each expert should have exactly one replica
|
# With equal weights, each expert should have exactly one replica
|
||||||
assert torch.all(
|
assert torch.all(logcnt == 1), (
|
||||||
logcnt == 1
|
"With equal weights and no replication, "
|
||||||
), "With equal weights and no replication, " \
|
"each expert should have exactly 1 replica"
|
||||||
"each expert should have exactly 1 replica"
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extreme_weight_imbalance():
|
def test_extreme_weight_imbalance():
|
||||||
@@ -116,35 +122,37 @@ def test_extreme_weight_imbalance():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 12)
|
assert phy2log.shape == (1, 12)
|
||||||
assert logcnt.shape == (1, 8)
|
assert logcnt.shape == (1, 8)
|
||||||
|
|
||||||
# Expert with highest weight (index 0) should have more replicas
|
# Expert with highest weight (index 0) should have more replicas
|
||||||
assert (
|
assert logcnt[0, 0] > logcnt[0, 1], (
|
||||||
logcnt[0, 0]
|
"Expert with highest weight should have more replicas"
|
||||||
> logcnt[0, 1]), "Expert with highest weight should have more replicas"
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_layers():
|
def test_multiple_layers():
|
||||||
"""Test multiple layers case"""
|
"""Test multiple layers case"""
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
[10, 20, 30, 40, 50, 60], # First layer
|
[
|
||||||
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
|
[10, 20, 30, 40, 50, 60], # First layer
|
||||||
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
|
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
|
||||||
])
|
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
|
||||||
|
]
|
||||||
|
)
|
||||||
num_replicas = 8
|
num_replicas = 8
|
||||||
num_groups = 2
|
num_groups = 2
|
||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (3, 8)
|
assert phy2log.shape == (3, 8)
|
||||||
@@ -152,12 +160,12 @@ def test_multiple_layers():
|
|||||||
|
|
||||||
# Verify expert allocation is reasonable for each layer
|
# Verify expert allocation is reasonable for each layer
|
||||||
for layer in range(3):
|
for layer in range(3):
|
||||||
assert torch.all(phy2log[layer] >= 0) and torch.all(
|
assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), (
|
||||||
phy2log[layer] < 6
|
f"Layer {layer} physical to logical mappingshould be in range [0, 6)"
|
||||||
), f"Layer {layer} physical to logical mapping" \
|
)
|
||||||
"should be in range [0, 6)"
|
assert torch.sum(logcnt[layer]) == num_replicas, (
|
||||||
assert (torch.sum(logcnt[layer]) == num_replicas
|
f"Layer {layer} total replicas should be {num_replicas}"
|
||||||
), f"Layer {layer} total replicas should be {num_replicas}"
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_parameter_validation():
|
def test_parameter_validation():
|
||||||
@@ -179,17 +187,19 @@ def test_parameter_validation():
|
|||||||
|
|
||||||
def test_small_scale_hierarchical():
|
def test_small_scale_hierarchical():
|
||||||
"""Test small-scale hierarchical load balancing"""
|
"""Test small-scale hierarchical load balancing"""
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
|
[
|
||||||
])
|
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
|
||||||
|
]
|
||||||
|
)
|
||||||
num_replicas = 12
|
num_replicas = 12
|
||||||
num_groups = 4 # 4 groups, 2 experts each
|
num_groups = 4 # 4 groups, 2 experts each
|
||||||
num_nodes = 2 # 2 nodes
|
num_nodes = 2 # 2 nodes
|
||||||
num_gpus = 4 # 4 GPUs
|
num_gpus = 4 # 4 GPUs
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Verify basic constraints
|
# Verify basic constraints
|
||||||
assert phy2log.shape == (1, 12)
|
assert phy2log.shape == (1, 12)
|
||||||
@@ -199,8 +209,9 @@ def test_small_scale_hierarchical():
|
|||||||
|
|
||||||
# Expert with highest weight should have more replicas
|
# Expert with highest weight should have more replicas
|
||||||
max_weight_expert = torch.argmax(weight[0])
|
max_weight_expert = torch.argmax(weight[0])
|
||||||
assert (logcnt[0, max_weight_expert]
|
assert logcnt[0, max_weight_expert] >= 2, (
|
||||||
>= 2), "Highest weight expert should have multiple replicas"
|
"Highest weight expert should have multiple replicas"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_global_load_balance_fallback():
|
def test_global_load_balance_fallback():
|
||||||
@@ -213,9 +224,9 @@ def test_global_load_balance_fallback():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Should work normally, just using global load balancing strategy
|
# Should work normally, just using global load balancing strategy
|
||||||
assert phy2log.shape == (1, 8)
|
assert phy2log.shape == (1, 8)
|
||||||
@@ -235,9 +246,9 @@ def test_device_compatibility(device):
|
|||||||
num_nodes = 1
|
num_nodes = 1
|
||||||
num_gpus = 2
|
num_gpus = 2
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
|
|
||||||
# Function will convert to CPU internally, but should handle different
|
# Function will convert to CPU internally, but should handle different
|
||||||
# device inputs normally
|
# device inputs normally
|
||||||
@@ -250,7 +261,8 @@ def test_additional_cases():
|
|||||||
|
|
||||||
# Test case 1: Large-scale distributed setup
|
# Test case 1: Large-scale distributed setup
|
||||||
weight1 = torch.tensor(
|
weight1 = torch.tensor(
|
||||||
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]])
|
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
|
||||||
|
)
|
||||||
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
|
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
|
||||||
|
|
||||||
assert phy2log1.shape == (1, 24)
|
assert phy2log1.shape == (1, 24)
|
||||||
@@ -258,10 +270,12 @@ def test_additional_cases():
|
|||||||
assert torch.sum(logcnt1) == 24
|
assert torch.sum(logcnt1) == 24
|
||||||
|
|
||||||
# Test case 2: Different weight distributions
|
# Test case 2: Different weight distributions
|
||||||
weight2 = torch.tensor([
|
weight2 = torch.tensor(
|
||||||
[200, 150, 100, 50, 25, 12], # Decreasing weights
|
[
|
||||||
[12, 25, 50, 100, 150, 200], # Increasing weights
|
[200, 150, 100, 50, 25, 12], # Decreasing weights
|
||||||
])
|
[12, 25, 50, 100, 150, 200], # Increasing weights
|
||||||
|
]
|
||||||
|
)
|
||||||
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
|
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
|
||||||
|
|
||||||
assert phy2log2.shape == (2, 10)
|
assert phy2log2.shape == (2, 10)
|
||||||
@@ -274,19 +288,21 @@ def test_additional_cases():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
weight = torch.tensor([
|
weight = torch.tensor(
|
||||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
[
|
||||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||||
])
|
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
num_replicas = 16
|
num_replicas = 16
|
||||||
num_groups = 4
|
num_groups = 4
|
||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
phy2log, log2phy, logcnt = rebalance_experts(
|
||||||
num_groups, num_nodes,
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
num_gpus)
|
)
|
||||||
print(phy2log)
|
print(phy2log)
|
||||||
|
|
||||||
test_basic_rebalance()
|
test_basic_rebalance()
|
||||||
|
|||||||
@@ -9,11 +9,12 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from vllm.distributed.eplb.rebalance_execute import (
|
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||||
rearrange_expert_weights_inplace)
|
from vllm.distributed.parallel_state import (
|
||||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
ensure_model_parallel_initialized,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
init_distributed_environment)
|
init_distributed_environment,
|
||||||
|
)
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
|
|
||||||
@@ -22,13 +23,13 @@ def distributed_run(fn, world_size):
|
|||||||
processes: list[multiprocessing.Process] = []
|
processes: list[multiprocessing.Process] = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env: dict[str, str] = {}
|
env: dict[str, str] = {}
|
||||||
env['RANK'] = str(i)
|
env["RANK"] = str(i)
|
||||||
env['LOCAL_RANK'] = str(i)
|
env["LOCAL_RANK"] = str(i)
|
||||||
env['WORLD_SIZE'] = str(number_of_processes)
|
env["WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['MASTER_ADDR'] = 'localhost'
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env['MASTER_PORT'] = '12345'
|
env["MASTER_PORT"] = "12345"
|
||||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
p = multiprocessing.Process(target=fn, args=(env,))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
@@ -45,7 +46,7 @@ def worker_fn_wrapper(fn):
|
|||||||
# and update the environment variables in the function
|
# and update the environment variables in the function
|
||||||
def wrapped_fn(env):
|
def wrapped_fn(env):
|
||||||
update_environment_variables(env)
|
update_environment_variables(env)
|
||||||
local_rank = os.environ['LOCAL_RANK']
|
local_rank = os.environ["LOCAL_RANK"]
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -60,20 +61,20 @@ def worker_fn_wrapper(fn):
|
|||||||
|
|
||||||
|
|
||||||
def create_expert_indices_with_redundancy(
|
def create_expert_indices_with_redundancy(
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_logical_experts: int,
|
num_logical_experts: int,
|
||||||
total_physical_experts: int,
|
total_physical_experts: int,
|
||||||
redundancy_config: list[int], # redundancy for each logical expert
|
redundancy_config: list[int], # redundancy for each logical expert
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Create expert indices with redundancy.
|
Create expert indices with redundancy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_layers: number of layers
|
num_layers: number of layers
|
||||||
num_logical_experts: number of logical experts
|
num_logical_experts: number of logical experts
|
||||||
total_physical_experts: total number of physical experts
|
total_physical_experts: total number of physical experts
|
||||||
redundancy_config: redundancy for each logical expert
|
redundancy_config: redundancy for each logical expert
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
indices: Shape (num_layers, total_physical_experts)
|
indices: Shape (num_layers, total_physical_experts)
|
||||||
"""
|
"""
|
||||||
@@ -106,11 +107,11 @@ def create_expert_weights(
|
|||||||
) -> list[list[torch.Tensor]]:
|
) -> list[list[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Create fake expert weights tensor for testing.
|
Create fake expert weights tensor for testing.
|
||||||
|
|
||||||
Use `arange` to generate predictable weights values, based on logical
|
Use `arange` to generate predictable weights values, based on logical
|
||||||
expert ID.
|
expert ID.
|
||||||
All replicas of the same logical expert should have the same weights.
|
All replicas of the same logical expert should have the same weights.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
physical_to_logical_mapping: Shape (num_layers, num_local_experts)
|
physical_to_logical_mapping: Shape (num_layers, num_local_experts)
|
||||||
mapping[layer, physical_pos] = logical_expert_id
|
mapping[layer, physical_pos] = logical_expert_id
|
||||||
@@ -120,27 +121,27 @@ def create_expert_weights(
|
|||||||
for layer in range(num_layers):
|
for layer in range(num_layers):
|
||||||
layer_weights = []
|
layer_weights = []
|
||||||
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
||||||
weight_tensor = torch.zeros(num_local_experts,
|
weight_tensor = torch.zeros(
|
||||||
hidden_size,
|
num_local_experts, hidden_size, device=device, dtype=torch.float32
|
||||||
device=device,
|
)
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
for local_expert in range(num_local_experts):
|
for local_expert in range(num_local_experts):
|
||||||
# Get the logical expert ID for this physical expert
|
# Get the logical expert ID for this physical expert
|
||||||
global_pos = rank * num_local_experts + local_expert
|
global_pos = rank * num_local_experts + local_expert
|
||||||
logical_expert_id = physical_to_logical_mapping[
|
logical_expert_id = physical_to_logical_mapping[
|
||||||
layer, global_pos].item()
|
layer, global_pos
|
||||||
|
].item()
|
||||||
|
|
||||||
# Generate weights based on logical expert ID
|
# Generate weights based on logical expert ID
|
||||||
# (so that all replicas of the same logical expert have the
|
# (so that all replicas of the same logical expert have the
|
||||||
# same weights)
|
# same weights)
|
||||||
base_value = (logical_expert_id * 1000 + layer * 100 +
|
base_value = logical_expert_id * 1000 + layer * 100 + weight_idx * 10
|
||||||
weight_idx * 10)
|
weight_tensor[local_expert] = torch.arange(
|
||||||
weight_tensor[local_expert] = torch.arange(base_value,
|
base_value,
|
||||||
base_value +
|
base_value + hidden_size,
|
||||||
hidden_size,
|
device=device,
|
||||||
device=device,
|
dtype=torch.float32,
|
||||||
dtype=torch.float32)
|
)
|
||||||
|
|
||||||
layer_weights.append(weight_tensor)
|
layer_weights.append(weight_tensor)
|
||||||
expert_weights.append(layer_weights)
|
expert_weights.append(layer_weights)
|
||||||
@@ -182,12 +183,15 @@ def verify_expert_weights_after_shuffle(
|
|||||||
|
|
||||||
# Check if the weights are correct
|
# Check if the weights are correct
|
||||||
actual_weights = weight_tensor[local_expert]
|
actual_weights = weight_tensor[local_expert]
|
||||||
expected_base = (expected_logical_expert * 1000 + layer * 100 +
|
expected_base = (
|
||||||
weight_idx * 10)
|
expected_logical_expert * 1000 + layer * 100 + weight_idx * 10
|
||||||
expected_weights = torch.arange(expected_base,
|
)
|
||||||
expected_base + hidden_size,
|
expected_weights = torch.arange(
|
||||||
device=actual_weights.device,
|
expected_base,
|
||||||
dtype=actual_weights.dtype)
|
expected_base + hidden_size,
|
||||||
|
device=actual_weights.device,
|
||||||
|
dtype=actual_weights.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
actual_weights,
|
actual_weights,
|
||||||
@@ -195,7 +199,8 @@ def verify_expert_weights_after_shuffle(
|
|||||||
msg=f"Layer {layer}, weight {weight_idx},"
|
msg=f"Layer {layer}, weight {weight_idx},"
|
||||||
f"local expert {local_expert}: "
|
f"local expert {local_expert}: "
|
||||||
f"weights do not match. "
|
f"weights do not match. "
|
||||||
f"Expected logical expert {expected_logical_expert}")
|
f"Expected logical expert {expected_logical_expert}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_redundant_experts_have_same_weights(
|
def verify_redundant_experts_have_same_weights(
|
||||||
@@ -222,23 +227,23 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
total_physical_experts,
|
total_physical_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
device=expert_weights[layer][weight_idx].device,
|
device=expert_weights[layer][weight_idx].device,
|
||||||
dtype=expert_weights[layer][weight_idx].dtype)
|
dtype=expert_weights[layer][weight_idx].dtype,
|
||||||
|
)
|
||||||
|
|
||||||
# Use all_gather to collect expert weights from current node
|
# Use all_gather to collect expert weights from current node
|
||||||
# expert_weights[layer][weight_idx] shape:
|
# expert_weights[layer][weight_idx] shape:
|
||||||
# [num_local_experts, hidden_size]
|
# [num_local_experts, hidden_size]
|
||||||
local_weights = expert_weights[layer][
|
local_weights = expert_weights[layer][
|
||||||
weight_idx] # [num_local_experts, hidden_size]
|
weight_idx
|
||||||
|
] # [num_local_experts, hidden_size]
|
||||||
|
|
||||||
# Split tensor along dim 0 into a list for all_gather
|
# Split tensor along dim 0 into a list for all_gather
|
||||||
gathered_weights_list = torch.chunk(gathered_weights,
|
gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0)
|
||||||
world_size,
|
|
||||||
dim=0)
|
|
||||||
|
|
||||||
torch.distributed.all_gather(
|
torch.distributed.all_gather(
|
||||||
# Output list: each element corresponds to one rank's weights
|
# Output list: each element corresponds to one rank's weights
|
||||||
list(gathered_weights_list),
|
list(gathered_weights_list),
|
||||||
local_weights # Input: current rank's local weights
|
local_weights, # Input: current rank's local weights
|
||||||
)
|
)
|
||||||
|
|
||||||
all_weights.append(gathered_weights)
|
all_weights.append(gathered_weights)
|
||||||
@@ -266,7 +271,8 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
msg=f"Layer {layer}, weight {weight_idx},"
|
msg=f"Layer {layer}, weight {weight_idx},"
|
||||||
f"logical expert {logical_expert_id}: "
|
f"logical expert {logical_expert_id}: "
|
||||||
f"Physical expert {physical_pos} has different weights"
|
f"Physical expert {physical_pos} has different weights"
|
||||||
f"than expected")
|
f"than expected",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -290,10 +296,11 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
# 4 GPU, 8 experts per GPU
|
# 4 GPU, 8 experts per GPU
|
||||||
# 16 logical experts, 32 physical experts, 16 redundant experts
|
# 16 logical experts, 32 physical experts, 16 redundant experts
|
||||||
(4, 8, 8, 16),
|
(4, 8, 8, 16),
|
||||||
])
|
],
|
||||||
def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
)
|
||||||
num_local_experts,
|
def test_rearrange_expert_weights_with_redundancy(
|
||||||
num_logical_experts):
|
world_size, num_layers, num_local_experts, num_logical_experts
|
||||||
|
):
|
||||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||||
|
|
||||||
if torch.cuda.device_count() < world_size:
|
if torch.cuda.device_count() < world_size:
|
||||||
@@ -304,8 +311,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||||
# to expert parallel)
|
# to expert parallel)
|
||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size=world_size,
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
pipeline_model_parallel_size=1)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group = get_tp_group().cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
@@ -316,8 +323,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
hidden_sizes = [32, 64] # Two different weight matrices
|
hidden_sizes = [32, 64] # Two different weight matrices
|
||||||
|
|
||||||
# Create old expert indices (with redundancy)
|
# Create old expert indices (with redundancy)
|
||||||
redundancy_config = create_redundancy_config(num_logical_experts,
|
redundancy_config = create_redundancy_config(
|
||||||
total_physical_experts)
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
|
|
||||||
old_indices = create_expert_indices_with_redundancy(
|
old_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers,
|
num_layers,
|
||||||
@@ -328,7 +336,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
|
|
||||||
# Create new expert indices (with redundancy)
|
# Create new expert indices (with redundancy)
|
||||||
new_redundancy_config = create_redundancy_config(
|
new_redundancy_config = create_redundancy_config(
|
||||||
num_logical_experts, total_physical_experts)
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
new_indices = create_expert_indices_with_redundancy(
|
new_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers,
|
num_layers,
|
||||||
num_logical_experts,
|
num_logical_experts,
|
||||||
@@ -337,9 +346,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create expert weights
|
# Create expert weights
|
||||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
expert_weights = create_expert_weights(
|
||||||
hidden_sizes, ep_rank, device,
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
old_indices)
|
)
|
||||||
|
|
||||||
# Execute weight rearrangement
|
# Execute weight rearrangement
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
@@ -383,8 +392,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size=world_size,
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
pipeline_model_parallel_size=1)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group = get_tp_group().cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
@@ -401,12 +410,12 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
|
|
||||||
# Same indices - no change
|
# Same indices - no change
|
||||||
indices = create_expert_indices_with_redundancy(
|
indices = create_expert_indices_with_redundancy(
|
||||||
num_layers, num_logical_experts, total_physical_experts,
|
num_layers, num_logical_experts, total_physical_experts, redundancy_config
|
||||||
redundancy_config)
|
)
|
||||||
|
|
||||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
expert_weights = create_expert_weights(
|
||||||
hidden_sizes, ep_rank, device,
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
|
||||||
indices)
|
)
|
||||||
|
|
||||||
# Save original weights
|
# Save original weights
|
||||||
original_weights = []
|
original_weights = []
|
||||||
@@ -422,7 +431,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
indices, # Same indices
|
indices, # Same indices
|
||||||
expert_weights,
|
expert_weights,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=False)
|
is_profile=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify that the weights have not changed
|
# Verify that the weights have not changed
|
||||||
for layer in range(num_layers):
|
for layer in range(num_layers):
|
||||||
@@ -430,8 +440,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
expert_weights[layer][weight_idx],
|
expert_weights[layer][weight_idx],
|
||||||
original_weights[layer][weight_idx],
|
original_weights[layer][weight_idx],
|
||||||
msg=f"Layer {layer}, weight {weight_idx} should remain "
|
msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
|
||||||
f"unchanged")
|
)
|
||||||
|
|
||||||
distributed_run(worker_fn, world_size)
|
distributed_run(worker_fn, world_size)
|
||||||
|
|
||||||
@@ -446,8 +456,8 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size=world_size,
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
pipeline_model_parallel_size=1)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group = get_tp_group().cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
@@ -460,21 +470,23 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
hidden_sizes = [32]
|
hidden_sizes = [32]
|
||||||
|
|
||||||
# Create different index distributions
|
# Create different index distributions
|
||||||
old_redundancy = create_redundancy_config(num_logical_experts,
|
old_redundancy = create_redundancy_config(
|
||||||
total_physical_experts)
|
num_logical_experts, total_physical_experts
|
||||||
new_redundancy = create_redundancy_config(num_logical_experts,
|
)
|
||||||
total_physical_experts)
|
new_redundancy = create_redundancy_config(
|
||||||
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
|
|
||||||
old_indices = create_expert_indices_with_redundancy(
|
old_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers, num_logical_experts, total_physical_experts,
|
num_layers, num_logical_experts, total_physical_experts, old_redundancy
|
||||||
old_redundancy)
|
)
|
||||||
new_indices = create_expert_indices_with_redundancy(
|
new_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers, num_logical_experts, total_physical_experts,
|
num_layers, num_logical_experts, total_physical_experts, new_redundancy
|
||||||
new_redundancy)
|
)
|
||||||
|
|
||||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
expert_weights = create_expert_weights(
|
||||||
hidden_sizes, ep_rank, device,
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
old_indices)
|
)
|
||||||
|
|
||||||
# Save original weights
|
# Save original weights
|
||||||
original_weights = []
|
original_weights = []
|
||||||
@@ -490,7 +502,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
new_indices,
|
new_indices,
|
||||||
expert_weights,
|
expert_weights,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=True # Profile mode
|
is_profile=True, # Profile mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# In profile mode, the weights should remain unchanged
|
# In profile mode, the weights should remain unchanged
|
||||||
@@ -499,6 +511,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
expert_weights[layer][weight_idx],
|
expert_weights[layer][weight_idx],
|
||||||
original_weights[layer][weight_idx],
|
original_weights[layer][weight_idx],
|
||||||
msg="In profile mode, the weights should remain unchanged")
|
msg="In profile mode, the weights should remain unchanged",
|
||||||
|
)
|
||||||
|
|
||||||
distributed_run(worker_fn, world_size)
|
distributed_run(worker_fn, world_size)
|
||||||
|
|||||||
@@ -6,24 +6,29 @@ import time
|
|||||||
import msgspec
|
import msgspec
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
from vllm.distributed.kv_events import (
|
||||||
NullEventPublisher)
|
EventBatch,
|
||||||
|
EventPublisherFactory,
|
||||||
|
NullEventPublisher,
|
||||||
|
)
|
||||||
|
|
||||||
DP_RANK = 0
|
DP_RANK = 0
|
||||||
|
|
||||||
|
|
||||||
class EventSample(
|
class EventSample(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
tag=True, # type: ignore
|
tag=True, # type: ignore
|
||||||
array_like=True # type: ignore
|
array_like=True, # type: ignore
|
||||||
):
|
):
|
||||||
"""Test event for publisher testing"""
|
"""Test event for publisher testing"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
|
|
||||||
class SampleBatch(EventBatch):
|
class SampleBatch(EventBatch):
|
||||||
"""Test event batch for publisher testing"""
|
"""Test event batch for publisher testing"""
|
||||||
|
|
||||||
events: list[EventSample]
|
events: list[EventSample]
|
||||||
|
|
||||||
|
|
||||||
@@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber):
|
|||||||
|
|
||||||
seq, received = result
|
seq, received = result
|
||||||
assert seq == 0, "Sequence number mismatch"
|
assert seq == 0, "Sequence number mismatch"
|
||||||
assert received.ts == pytest.approx(test_batch.ts,
|
assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch"
|
||||||
abs=0.1), ("Timestamp mismatch")
|
assert len(received.events) == len(test_batch.events), "Number of events mismatch"
|
||||||
assert len(received.events) == len(
|
|
||||||
test_batch.events), ("Number of events mismatch")
|
|
||||||
|
|
||||||
for i, event in enumerate(received.events):
|
for i, event in enumerate(received.events):
|
||||||
assert event.id == i, "Event id mismatch"
|
assert event.id == i, "Event id mismatch"
|
||||||
@@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber):
|
|||||||
assert len(replayed) > 0, "No replayed messages received"
|
assert len(replayed) > 0, "No replayed messages received"
|
||||||
seqs = [seq for seq, _ in replayed]
|
seqs = [seq for seq, _ in replayed]
|
||||||
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
|
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
|
||||||
assert seqs == list(range(min(seqs),
|
assert seqs == list(range(min(seqs), max(seqs) + 1)), (
|
||||||
max(seqs) +
|
"Replayed messages not consecutive"
|
||||||
1)), ("Replayed messages not consecutive")
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_buffer_limit(publisher, subscriber, publisher_config):
|
def test_buffer_limit(publisher, subscriber, publisher_config):
|
||||||
@@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config):
|
|||||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||||
|
|
||||||
from .conftest import MockSubscriber
|
from .conftest import MockSubscriber
|
||||||
|
|
||||||
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
||||||
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
|
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
|
||||||
|
|
||||||
@@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config):
|
|||||||
|
|
||||||
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
|
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
|
||||||
assert all(msg is not None for msg in foo_received), (
|
assert all(msg is not None for msg in foo_received), (
|
||||||
"Subscriber with matching topic should receive messages")
|
"Subscriber with matching topic should receive messages"
|
||||||
|
)
|
||||||
|
|
||||||
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
|
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
|
||||||
assert all(msg is None for msg in bar_received), (
|
assert all(msg is None for msg in bar_received), (
|
||||||
"Subscriber with non-matching topic should receive no messages")
|
"Subscriber with non-matching topic should receive no messages"
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
pub.shutdown()
|
pub.shutdown()
|
||||||
sub_foo.close()
|
sub_foo.close()
|
||||||
@@ -178,8 +184,7 @@ def test_high_volume(publisher, subscriber):
|
|||||||
|
|
||||||
publisher_thread.join()
|
publisher_thread.join()
|
||||||
|
|
||||||
assert len(received) >= num_batches * 0.9, (
|
assert len(received) >= num_batches * 0.9, "We should have received most messages"
|
||||||
"We should have received most messages")
|
|
||||||
|
|
||||||
seqs = [seq for seq, _ in received]
|
seqs = [seq for seq, _ in received]
|
||||||
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
||||||
@@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
|||||||
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
|
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
|
||||||
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
|
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
|
||||||
expected_endpoint_1 = base_endpoint.replace(
|
expected_endpoint_1 = base_endpoint.replace(
|
||||||
":5557", ":5558") # rank 1 gets port + 1
|
":5557", ":5558"
|
||||||
|
) # rank 1 gets port + 1
|
||||||
else:
|
else:
|
||||||
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
|
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
|
||||||
expected_endpoint_0 = base_endpoint # rank 0 gets base
|
expected_endpoint_0 = base_endpoint # rank 0 gets base
|
||||||
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
|
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
|
||||||
|
|
||||||
from .conftest import MockSubscriber
|
from .conftest import MockSubscriber
|
||||||
|
|
||||||
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
|
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
|
||||||
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
|
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
|
||||||
|
|
||||||
@@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
|||||||
|
|
||||||
# Verify DP rank tagging
|
# Verify DP rank tagging
|
||||||
assert received_0.data_parallel_rank == 0, (
|
assert received_0.data_parallel_rank == 0, (
|
||||||
f"Expected DP rank 0, got {received_0.data_parallel_rank}")
|
f"Expected DP rank 0, got {received_0.data_parallel_rank}"
|
||||||
|
)
|
||||||
assert received_1.data_parallel_rank == 1, (
|
assert received_1.data_parallel_rank == 1, (
|
||||||
f"Expected DP rank 1, got {received_1.data_parallel_rank}")
|
f"Expected DP rank 1, got {received_1.data_parallel_rank}"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify event content is correct
|
# Verify event content is correct
|
||||||
assert len(
|
assert len(received_0.events) == 2, "Wrong number of events from rank 0"
|
||||||
received_0.events) == 2, "Wrong number of events from rank 0"
|
assert len(received_1.events) == 3, "Wrong number of events from rank 1"
|
||||||
assert len(
|
|
||||||
received_1.events) == 3, "Wrong number of events from rank 1"
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
pub_0.shutdown()
|
pub_0.shutdown()
|
||||||
|
|||||||
@@ -46,28 +46,24 @@ class EPTestSettings:
|
|||||||
):
|
):
|
||||||
return EPTestSettings(
|
return EPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False),
|
||||||
eager_mode=False,
|
ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True),
|
||||||
chunked_prefill=False),
|
ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
eager_mode=False,
|
tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True
|
||||||
chunked_prefill=True),
|
),
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
eager_mode=True,
|
tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False
|
||||||
chunked_prefill=False),
|
),
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
eager_mode=False,
|
|
||||||
chunked_prefill=True),
|
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
eager_mode=True,
|
|
||||||
chunked_prefill=False),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=EPTestOptions(trust_remote_code=trust_remote_code,
|
test_options=EPTestOptions(
|
||||||
tokenizer_mode=tokenizer_mode,
|
trust_remote_code=trust_remote_code,
|
||||||
load_format=load_format,
|
tokenizer_mode=tokenizer_mode,
|
||||||
hf_overrides=hf_overrides),
|
load_format=load_format,
|
||||||
|
hf_overrides=hf_overrides,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -82,16 +78,16 @@ class EPTestSettings:
|
|||||||
):
|
):
|
||||||
return EPTestSettings(
|
return EPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
|
||||||
eager_mode=True,
|
|
||||||
chunked_prefill=False),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp"],
|
distributed_backends=["mp"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=EPTestOptions(trust_remote_code=trust_remote_code,
|
test_options=EPTestOptions(
|
||||||
tokenizer_mode=tokenizer_mode,
|
trust_remote_code=trust_remote_code,
|
||||||
load_format=load_format,
|
tokenizer_mode=tokenizer_mode,
|
||||||
hf_overrides=hf_overrides),
|
load_format=load_format,
|
||||||
|
hf_overrides=hf_overrides,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_name: str):
|
def iter_params(self, model_name: str):
|
||||||
@@ -99,8 +95,13 @@ class EPTestSettings:
|
|||||||
|
|
||||||
for parallel_setup in self.parallel_setups:
|
for parallel_setup in self.parallel_setups:
|
||||||
for distributed_backend in self.distributed_backends:
|
for distributed_backend in self.distributed_backends:
|
||||||
yield (model_name, parallel_setup, distributed_backend,
|
yield (
|
||||||
self.runner, opts)
|
model_name,
|
||||||
|
parallel_setup,
|
||||||
|
distributed_backend,
|
||||||
|
self.runner,
|
||||||
|
opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# NOTE: You can adjust tp_base locally to fit the model in GPU
|
# NOTE: You can adjust tp_base locally to fit the model in GPU
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ import pytest
|
|||||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||||
|
|
||||||
|
|
||||||
def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
|
def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts):
|
||||||
global_num_experts):
|
|
||||||
"""Verify that the expert map follows the round_robin pattern."""
|
"""Verify that the expert map follows the round_robin pattern."""
|
||||||
# Calculate expected local experts (supporting non-divisible cases)
|
# Calculate expected local experts (supporting non-divisible cases)
|
||||||
base_experts = global_num_experts // ep_size
|
base_experts = global_num_experts // ep_size
|
||||||
@@ -30,24 +29,21 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
|
|||||||
if global_expert_id in expected_expert_ids:
|
if global_expert_id in expected_expert_ids:
|
||||||
local_expert_id = expert_map[global_expert_id]
|
local_expert_id = expert_map[global_expert_id]
|
||||||
expected_local_id = expected_expert_ids.index(global_expert_id)
|
expected_local_id = expected_expert_ids.index(global_expert_id)
|
||||||
assert (
|
assert local_expert_id == expected_local_id, (
|
||||||
local_expert_id == expected_local_id
|
f"Global expert {global_expert_id} should map to local expert "
|
||||||
), f"Global expert {global_expert_id} should map to local expert " \
|
|
||||||
f"{expected_local_id}, got {local_expert_id}"
|
f"{expected_local_id}, got {local_expert_id}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert expert_map[global_expert_id] == -1, (
|
||||||
expert_map[global_expert_id] == -1
|
f"Global expert {global_expert_id} should not be mapped to this rank"
|
||||||
), f"Global expert {global_expert_id} should not be mapped to " \
|
)
|
||||||
f"this rank"
|
|
||||||
|
|
||||||
# Verify that all local expert IDs are consecutive starting from 0
|
# Verify that all local expert IDs are consecutive starting from 0
|
||||||
local_expert_ids = [
|
local_expert_ids = [expert_map[global_id] for global_id in expected_expert_ids]
|
||||||
expert_map[global_id] for global_id in expected_expert_ids
|
|
||||||
]
|
|
||||||
expected_local_ids = list(range(local_num_experts))
|
expected_local_ids = list(range(local_num_experts))
|
||||||
assert (
|
assert local_expert_ids == expected_local_ids, (
|
||||||
local_expert_ids == expected_local_ids
|
f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
|
||||||
), f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
||||||
@@ -78,8 +74,9 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
|
|||||||
|
|
||||||
for test_global_experts, test_ep_size in test_cases:
|
for test_global_experts, test_ep_size in test_cases:
|
||||||
# Ensure ep_size matches world_size
|
# Ensure ep_size matches world_size
|
||||||
assert (test_ep_size == world_size
|
assert test_ep_size == world_size, (
|
||||||
), f"ep_size {test_ep_size} must equal world_size {world_size}"
|
f"ep_size {test_ep_size} must equal world_size {world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
# Test each rank
|
# Test each rank
|
||||||
for ep_rank in range(world_size):
|
for ep_rank in range(world_size):
|
||||||
@@ -98,21 +95,22 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
|
|||||||
expert_placement_strategy=expert_placement_strategy,
|
expert_placement_strategy=expert_placement_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert test_local_experts == expected_test_local, (
|
||||||
test_local_experts == expected_test_local
|
f"For {test_global_experts} experts on {test_ep_size} ranks, "
|
||||||
), f"For {test_global_experts} experts on {test_ep_size} ranks, " \
|
f"rank {ep_rank}: expected {expected_test_local} local"
|
||||||
f"rank {ep_rank}: expected {expected_test_local} local" \
|
|
||||||
f"experts, got {test_local_experts}"
|
f"experts, got {test_local_experts}"
|
||||||
|
)
|
||||||
|
|
||||||
if test_expert_map is not None:
|
if test_expert_map is not None:
|
||||||
assert test_expert_map.shape == (
|
assert test_expert_map.shape == (test_global_experts,), (
|
||||||
test_global_experts,
|
f"Expected expert map shape ({test_global_experts},), "
|
||||||
), f"Expected expert map shape ({test_global_experts},), " \
|
|
||||||
f"got {test_expert_map.shape}"
|
f"got {test_expert_map.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify round_robin pattern for this test case
|
# Verify round_robin pattern for this test case
|
||||||
verify_round_robin_pattern(test_expert_map, ep_rank,
|
verify_round_robin_pattern(
|
||||||
test_ep_size, test_global_experts)
|
test_expert_map, ep_rank, test_ep_size, test_global_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
|
||||||
@@ -147,28 +145,81 @@ def test_determine_expert_map_comprehensive():
|
|||||||
# expert_placement_strategy, expected_local, expected_map_pattern)
|
# expert_placement_strategy, expected_local, expected_map_pattern)
|
||||||
test_cases = [
|
test_cases = [
|
||||||
# Round robin placement tests
|
# Round robin placement tests
|
||||||
(2, 0, 8, "round_robin", 4, [0, -1, 1, -1, 2, -1, 3,
|
(
|
||||||
-1]), # rank 0 gets even experts
|
2,
|
||||||
(2, 1, 8, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1,
|
0,
|
||||||
3]), # rank 1 gets odd experts
|
8,
|
||||||
(2, 0, 9, "round_robin", 5, [0, -1, 1, -1, 2, -1, 3, -1, 4
|
"round_robin",
|
||||||
]), # rank 0 gets 5 experts (even + last)
|
4,
|
||||||
(2, 1, 9, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3,
|
[0, -1, 1, -1, 2, -1, 3, -1],
|
||||||
-1]), # rank 1 gets 4 experts (odd)
|
), # rank 0 gets even experts
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
4,
|
||||||
|
[-1, 0, -1, 1, -1, 2, -1, 3],
|
||||||
|
), # rank 1 gets odd experts
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
0,
|
||||||
|
9,
|
||||||
|
"round_robin",
|
||||||
|
5,
|
||||||
|
[0, -1, 1, -1, 2, -1, 3, -1, 4],
|
||||||
|
), # rank 0 gets 5 experts (even + last)
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
9,
|
||||||
|
"round_robin",
|
||||||
|
4,
|
||||||
|
[-1, 0, -1, 1, -1, 2, -1, 3, -1],
|
||||||
|
), # rank 1 gets 4 experts (odd)
|
||||||
# 4-rank tests
|
# 4-rank tests
|
||||||
(4, 0, 8, "round_robin", 2, [0, -1, -1, -1, 1, -1, -1,
|
(
|
||||||
-1]), # rank 0 gets experts 0, 4
|
4,
|
||||||
(4, 1, 8, "round_robin", 2, [-1, 0, -1, -1, -1, 1, -1,
|
0,
|
||||||
-1]), # rank 1 gets experts 1, 5
|
8,
|
||||||
(4, 2, 8, "round_robin", 2, [-1, -1, 0, -1, -1, -1, 1,
|
"round_robin",
|
||||||
-1]), # rank 2 gets experts 2, 6
|
2,
|
||||||
(4, 3, 8, "round_robin", 2, [-1, -1, -1, 0, -1, -1, -1,
|
[0, -1, -1, -1, 1, -1, -1, -1],
|
||||||
1]), # rank 3 gets experts 3, 7
|
), # rank 0 gets experts 0, 4
|
||||||
|
(
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
2,
|
||||||
|
[-1, 0, -1, -1, -1, 1, -1, -1],
|
||||||
|
), # rank 1 gets experts 1, 5
|
||||||
|
(
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
2,
|
||||||
|
[-1, -1, 0, -1, -1, -1, 1, -1],
|
||||||
|
), # rank 2 gets experts 2, 6
|
||||||
|
(
|
||||||
|
4,
|
||||||
|
3,
|
||||||
|
8,
|
||||||
|
"round_robin",
|
||||||
|
2,
|
||||||
|
[-1, -1, -1, 0, -1, -1, -1, 1],
|
||||||
|
), # rank 3 gets experts 3, 7
|
||||||
]
|
]
|
||||||
|
|
||||||
for ep_size, ep_rank, global_num_experts, expert_placement_strategy, \
|
for (
|
||||||
expected_local, expected_map_pattern in test_cases:
|
ep_size,
|
||||||
|
ep_rank,
|
||||||
|
global_num_experts,
|
||||||
|
expert_placement_strategy,
|
||||||
|
expected_local,
|
||||||
|
expected_map_pattern,
|
||||||
|
) in test_cases:
|
||||||
local_num_experts, expert_map = determine_expert_map(
|
local_num_experts, expert_map = determine_expert_map(
|
||||||
ep_size=ep_size,
|
ep_size=ep_size,
|
||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
@@ -176,19 +227,21 @@ def test_determine_expert_map_comprehensive():
|
|||||||
expert_placement_strategy=expert_placement_strategy,
|
expert_placement_strategy=expert_placement_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert local_num_experts == expected_local, \
|
assert local_num_experts == expected_local, (
|
||||||
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
|
f"ep_size={ep_size}, ep_rank={ep_rank}, "
|
||||||
f"global_num_experts={global_num_experts}, " \
|
f"global_num_experts={global_num_experts}, "
|
||||||
f"expert_placement_strategy={expert_placement_strategy}: " \
|
f"expert_placement_strategy={expert_placement_strategy}: "
|
||||||
f"expected {expected_local} local experts, got {local_num_experts}"
|
f"expected {expected_local} local experts, got {local_num_experts}"
|
||||||
|
)
|
||||||
|
|
||||||
if expected_map_pattern is None:
|
if expected_map_pattern is None:
|
||||||
assert expert_map is None, "Expected expert_map to be None"
|
assert expert_map is None, "Expected expert_map to be None"
|
||||||
else:
|
else:
|
||||||
assert expert_map is not None, "Expected expert_map to not be None"
|
assert expert_map is not None, "Expected expert_map to not be None"
|
||||||
actual_map = expert_map.tolist()
|
actual_map = expert_map.tolist()
|
||||||
assert actual_map == expected_map_pattern, \
|
assert actual_map == expected_map_pattern, (
|
||||||
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
|
f"ep_size={ep_size}, ep_rank={ep_rank}, "
|
||||||
f"global_num_experts={global_num_experts}, " \
|
f"global_num_experts={global_num_experts}, "
|
||||||
f"expert_placement_strategy={expert_placement_strategy}: " \
|
f"expert_placement_strategy={expert_placement_strategy}: "
|
||||||
f"expected map {expected_map_pattern}, got {actual_map}"
|
f"expected map {expected_map_pattern}, got {actual_map}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
|
from vllm.config import (
|
||||||
VllmConfig, set_current_vllm_config)
|
DeviceConfig,
|
||||||
|
KVTransferConfig,
|
||||||
|
ModelConfig,
|
||||||
|
VllmConfig,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||||
get_kv_connector_cache_layout)
|
get_kv_connector_cache_layout,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger("test_expert_parallel")
|
logger = init_logger("test_expert_parallel")
|
||||||
@@ -23,8 +29,9 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector():
|
|||||||
kv_connector="LMCacheConnectorV1",
|
kv_connector="LMCacheConnectorV1",
|
||||||
kv_role="kv_both",
|
kv_role="kv_both",
|
||||||
)
|
)
|
||||||
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
vllm_config = VllmConfig(
|
||||||
kv_transfer_config=kv_transfer_config)
|
device_config=DeviceConfig("cpu"), kv_transfer_config=kv_transfer_config
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
# Test with default settings
|
# Test with default settings
|
||||||
layout = get_kv_connector_cache_layout()
|
layout = get_kv_connector_cache_layout()
|
||||||
@@ -37,9 +44,11 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
|
|||||||
kv_role="kv_both",
|
kv_role="kv_both",
|
||||||
)
|
)
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
vllm_config = VllmConfig(
|
||||||
model_config=model_config,
|
device_config=DeviceConfig("cpu"),
|
||||||
kv_transfer_config=kv_transfer_config)
|
model_config=model_config,
|
||||||
|
kv_transfer_config=kv_transfer_config,
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
# Test with default settings
|
# Test with default settings
|
||||||
layout = get_kv_connector_cache_layout()
|
layout = get_kv_connector_cache_layout()
|
||||||
@@ -47,25 +56,22 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
|
|||||||
|
|
||||||
|
|
||||||
def test_get_kv_connector_cache_layout_with_multi_connector():
|
def test_get_kv_connector_cache_layout_with_multi_connector():
|
||||||
kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector",
|
kv_transfer_config = KVTransferConfig(
|
||||||
kv_role="kv_both",
|
kv_connector="MultiConnector",
|
||||||
kv_connector_extra_config={
|
kv_role="kv_both",
|
||||||
"connectors": [{
|
kv_connector_extra_config={
|
||||||
"kv_connector":
|
"connectors": [
|
||||||
"SharedStorageConnector",
|
{"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
|
||||||
"kv_role":
|
{"kv_connector": "NixlConnector", "kv_role": "kv_both"},
|
||||||
"kv_both"
|
]
|
||||||
}, {
|
},
|
||||||
"kv_connector":
|
)
|
||||||
"NixlConnector",
|
|
||||||
"kv_role":
|
|
||||||
"kv_both"
|
|
||||||
}]
|
|
||||||
})
|
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
vllm_config = VllmConfig(
|
||||||
model_config=model_config,
|
device_config=DeviceConfig("cpu"),
|
||||||
kv_transfer_config=kv_transfer_config)
|
model_config=model_config,
|
||||||
|
kv_transfer_config=kv_transfer_config,
|
||||||
|
)
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
# Test with default settings
|
# Test with default settings
|
||||||
layout = get_kv_connector_cache_layout()
|
layout = get_kv_connector_cache_layout()
|
||||||
|
|||||||
@@ -24,14 +24,13 @@ from vllm.utils import get_ip
|
|||||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not VLLM_MULTI_NODE,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 nodes to run the test.")
|
not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test."
|
||||||
|
)
|
||||||
def test_multi_node_assignment() -> None:
|
def test_multi_node_assignment() -> None:
|
||||||
|
|
||||||
# NOTE: important to keep this class definition here
|
# NOTE: important to keep this class definition here
|
||||||
# to let ray use cloudpickle to serialize it.
|
# to let ray use cloudpickle to serialize it.
|
||||||
class Actor:
|
class Actor:
|
||||||
|
|
||||||
def get_ip(self):
|
def get_ip(self):
|
||||||
return get_ip()
|
return get_ip()
|
||||||
|
|
||||||
@@ -41,8 +40,7 @@ def test_multi_node_assignment() -> None:
|
|||||||
|
|
||||||
current_ip = get_ip()
|
current_ip = get_ip()
|
||||||
workers = []
|
workers = []
|
||||||
for bundle_id, bundle in enumerate(
|
for bundle_id, bundle in enumerate(config.placement_group.bundle_specs):
|
||||||
config.placement_group.bundle_specs):
|
|
||||||
if not bundle.get("GPU", 0):
|
if not bundle.get("GPU", 0):
|
||||||
continue
|
continue
|
||||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||||
|
|||||||
@@ -11,15 +11,17 @@ import torch.multiprocessing as mp
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||||
CudaCommunicator)
|
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
|
||||||
from vllm.distributed.device_communicators.pynccl import (
|
|
||||||
register_nccl_symmetric_ops)
|
|
||||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||||
get_nccl_mem_pool, is_symmetric_memory_enabled)
|
get_nccl_mem_pool,
|
||||||
from vllm.distributed.parallel_state import (get_tp_group,
|
is_symmetric_memory_enabled,
|
||||||
init_distributed_environment,
|
)
|
||||||
initialize_model_parallel)
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_tp_group,
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
@@ -38,31 +40,32 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
|||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
"RANK": str(local_rank),
|
{
|
||||||
"LOCAL_RANK": str(local_rank),
|
"RANK": str(local_rank),
|
||||||
"WORLD_SIZE": str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
"MASTER_ADDR": "localhost",
|
"WORLD_SIZE": str(world_size),
|
||||||
"MASTER_PORT": "12345",
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
cuda_communicator = typing.cast(CudaCommunicator,
|
cuda_communicator = typing.cast(
|
||||||
get_tp_group().device_communicator)
|
CudaCommunicator, get_tp_group().device_communicator
|
||||||
|
)
|
||||||
pynccl_comm = cuda_communicator.pynccl_comm
|
pynccl_comm = cuda_communicator.pynccl_comm
|
||||||
if get_nccl_mem_pool() is None:
|
if get_nccl_mem_pool() is None:
|
||||||
pytest.skip("NCCL allocator compilation failed "
|
pytest.skip(
|
||||||
"(probably missing NCCL headers).")
|
"NCCL allocator compilation failed (probably missing NCCL headers)."
|
||||||
|
)
|
||||||
if not is_symmetric_memory_enabled():
|
if not is_symmetric_memory_enabled():
|
||||||
pytest.skip("NCCL symmetric memory allreduce is disabled.")
|
pytest.skip("NCCL symmetric memory allreduce is disabled.")
|
||||||
|
|
||||||
register_nccl_symmetric_ops(pynccl_comm)
|
register_nccl_symmetric_ops(pynccl_comm)
|
||||||
input = torch.randint(1,
|
input = torch.randint(1, 23, (test_size_elements,), dtype=dtype, device=device)
|
||||||
23, (test_size_elements, ),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
input_clone = input.clone()
|
input_clone = input.clone()
|
||||||
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
|
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
|
||||||
assert output is not None
|
assert output is not None
|
||||||
@@ -77,8 +80,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
|||||||
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
|
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("world_size", [2])
|
@pytest.mark.parametrize("world_size", [2])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
|
||||||
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
@@ -88,7 +90,5 @@ def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
|||||||
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
|
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
|
||||||
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
|
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
|
||||||
|
|
||||||
mp.spawn(nccl_symm_mem_allreduce_worker,
|
mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size)
|
||||||
args=(world_size, ),
|
|
||||||
nprocs=world_size)
|
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|||||||
@@ -32,12 +32,15 @@ if __name__ == "__main__":
|
|||||||
# Expected node count based on environment variable)
|
# Expected node count based on environment variable)
|
||||||
expected = int(os.environ.get("NUM_NODES", "1"))
|
expected = int(os.environ.get("NUM_NODES", "1"))
|
||||||
|
|
||||||
assert test_result == expected, \
|
assert test_result == expected, f"Expected {expected} nodes, got {test_result}"
|
||||||
f"Expected {expected} nodes, got {test_result}"
|
|
||||||
|
|
||||||
if pg == dist.group.WORLD:
|
if pg == dist.group.WORLD:
|
||||||
print(f"Node count test passed! Got {test_result} nodes "
|
print(
|
||||||
f"when using torch distributed!")
|
f"Node count test passed! Got {test_result} nodes "
|
||||||
|
f"when using torch distributed!"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Node count test passed! Got {test_result} nodes "
|
print(
|
||||||
f"when using StatelessProcessGroup!")
|
f"Node count test passed! Got {test_result} nodes "
|
||||||
|
f"when using StatelessProcessGroup!"
|
||||||
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
|||||||
all workers in a node other than the head node, which can cause the test
|
all workers in a node other than the head node, which can cause the test
|
||||||
to fail.
|
to fail.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -55,26 +56,17 @@ class PPTestSettings:
|
|||||||
):
|
):
|
||||||
return PPTestSettings(
|
return PPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=False),
|
||||||
pp_size=pp_base,
|
ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=False),
|
||||||
eager_mode=False),
|
ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=True),
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=False),
|
||||||
pp_size=2 * pp_base,
|
ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=True),
|
||||||
eager_mode=False),
|
|
||||||
ParallelSetup(tp_size=tp_base,
|
|
||||||
pp_size=2 * pp_base,
|
|
||||||
eager_mode=True),
|
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
pp_size=pp_base,
|
|
||||||
eager_mode=False),
|
|
||||||
ParallelSetup(tp_size=2 * tp_base,
|
|
||||||
pp_size=pp_base,
|
|
||||||
eager_mode=True),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=PPTestOptions(multi_node_only=multi_node_only,
|
test_options=PPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -86,17 +78,15 @@ class PPTestSettings:
|
|||||||
multi_node_only: bool = False,
|
multi_node_only: bool = False,
|
||||||
load_format: Optional[str] = None,
|
load_format: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
return PPTestSettings(
|
return PPTestSettings(
|
||||||
parallel_setups=[
|
parallel_setups=[
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=True),
|
||||||
pp_size=pp_base,
|
|
||||||
eager_mode=True),
|
|
||||||
],
|
],
|
||||||
distributed_backends=["mp"],
|
distributed_backends=["mp"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=PPTestOptions(multi_node_only=multi_node_only,
|
test_options=PPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_id: str):
|
def iter_params(self, model_id: str):
|
||||||
@@ -281,8 +271,10 @@ def _compare_tp(
|
|||||||
if num_gpus_available < tp_size * pp_size:
|
if num_gpus_available < tp_size * pp_size:
|
||||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
pytest.skip(
|
||||||
"multiprocessing distributed backend")
|
"Skipping multi-node pipeline parallel test for "
|
||||||
|
"multiprocessing distributed backend"
|
||||||
|
)
|
||||||
if multi_node_only and not VLLM_MULTI_NODE:
|
if multi_node_only and not VLLM_MULTI_NODE:
|
||||||
pytest.skip("Not in multi-node setting")
|
pytest.skip("Not in multi-node setting")
|
||||||
|
|
||||||
@@ -357,20 +349,16 @@ def _compare_tp(
|
|||||||
"mp",
|
"mp",
|
||||||
]
|
]
|
||||||
|
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(model_id, pp_args, tp_args, pp_env, tp_env, method=method)
|
||||||
pp_args,
|
|
||||||
tp_args,
|
|
||||||
pp_env,
|
|
||||||
tp_env,
|
|
||||||
method=method)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||||
"test_options"),
|
|
||||||
[
|
[
|
||||||
params for model_id, settings in TEXT_GENERATION_MODELS.items()
|
params
|
||||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
for model_id, settings in TEXT_GENERATION_MODELS.items()
|
||||||
|
for params in settings.iter_params(model_id)
|
||||||
|
if model_id in TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@@ -382,22 +370,25 @@ def test_tp_language_generation(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_tp(model_id,
|
_compare_tp(
|
||||||
parallel_setup,
|
model_id,
|
||||||
distributed_backend,
|
parallel_setup,
|
||||||
runner,
|
distributed_backend,
|
||||||
test_options,
|
runner,
|
||||||
num_gpus_available,
|
test_options,
|
||||||
method="generate",
|
num_gpus_available,
|
||||||
is_multimodal=False)
|
method="generate",
|
||||||
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||||
"test_options"),
|
|
||||||
[
|
[
|
||||||
params for model_id, settings in EMBEDDING_MODELS.items()
|
params
|
||||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
for model_id, settings in EMBEDDING_MODELS.items()
|
||||||
|
for params in settings.iter_params(model_id)
|
||||||
|
if model_id in TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@@ -409,22 +400,25 @@ def test_tp_language_embedding(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_tp(model_id,
|
_compare_tp(
|
||||||
parallel_setup,
|
model_id,
|
||||||
distributed_backend,
|
parallel_setup,
|
||||||
runner,
|
distributed_backend,
|
||||||
test_options,
|
runner,
|
||||||
num_gpus_available,
|
test_options,
|
||||||
method="encode",
|
num_gpus_available,
|
||||||
is_multimodal=False)
|
method="encode",
|
||||||
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||||
"test_options"),
|
|
||||||
[
|
[
|
||||||
params for model_id, settings in MULTIMODAL_MODELS.items()
|
params
|
||||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
for model_id, settings in MULTIMODAL_MODELS.items()
|
||||||
|
for params in settings.iter_params(model_id)
|
||||||
|
if model_id in TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@@ -436,11 +430,13 @@ def test_tp_multimodal_generation(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_tp(model_id,
|
_compare_tp(
|
||||||
parallel_setup,
|
model_id,
|
||||||
distributed_backend,
|
parallel_setup,
|
||||||
runner,
|
distributed_backend,
|
||||||
test_options,
|
runner,
|
||||||
num_gpus_available,
|
test_options,
|
||||||
method="generate",
|
num_gpus_available,
|
||||||
is_multimodal=True)
|
method="generate",
|
||||||
|
is_multimodal=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from vllm.distributed.utils import get_pp_indices
|
|||||||
|
|
||||||
|
|
||||||
def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
|
|
||||||
def _verify(partition_str, num_layers, pp_size, goldens):
|
def _verify(partition_str, num_layers, pp_size, goldens):
|
||||||
@@ -57,7 +56,8 @@ def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
|||||||
(5, 3, 0, (0, 2)),
|
(5, 3, 0, (0, 2)),
|
||||||
(5, 3, 1, (2, 4)),
|
(5, 3, 1, (2, 4)),
|
||||||
(5, 3, 2, (4, 5)),
|
(5, 3, 2, (4, 5)),
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_uneven_auto_partition(
|
def test_uneven_auto_partition(
|
||||||
num_hidden_layers: int,
|
num_hidden_layers: int,
|
||||||
pp_size: int,
|
pp_size: int,
|
||||||
|
|||||||
@@ -12,12 +12,18 @@ if TYPE_CHECKING:
|
|||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
|
@pytest.mark.parametrize(
|
||||||
(2, "JackFram/llama-160m"),
|
"PP_SIZE, MODEL_NAME",
|
||||||
])
|
[
|
||||||
@pytest.mark.parametrize("ATTN_BACKEND", [
|
(2, "JackFram/llama-160m"),
|
||||||
"FLASH_ATTN",
|
],
|
||||||
])
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"ATTN_BACKEND",
|
||||||
|
[
|
||||||
|
"FLASH_ATTN",
|
||||||
|
],
|
||||||
|
)
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_pp_cudagraph(
|
def test_pp_cudagraph(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
@@ -9,13 +9,15 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||||
tensor_model_parallel_all_reduce)
|
|
||||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
from vllm.distributed.parallel_state import (
|
||||||
get_world_group, graph_capture,
|
ensure_model_parallel_initialized,
|
||||||
init_distributed_environment)
|
get_world_group,
|
||||||
|
graph_capture,
|
||||||
|
init_distributed_environment,
|
||||||
|
)
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
|
|
||||||
@@ -24,13 +26,13 @@ def distributed_run(fn, world_size):
|
|||||||
processes: list[multiprocessing.Process] = []
|
processes: list[multiprocessing.Process] = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env: dict[str, str] = {}
|
env: dict[str, str] = {}
|
||||||
env['RANK'] = str(i)
|
env["RANK"] = str(i)
|
||||||
env['LOCAL_RANK'] = str(i)
|
env["LOCAL_RANK"] = str(i)
|
||||||
env['WORLD_SIZE'] = str(number_of_processes)
|
env["WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['MASTER_ADDR'] = 'localhost'
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env['MASTER_PORT'] = '12345'
|
env["MASTER_PORT"] = "12345"
|
||||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
p = multiprocessing.Process(target=fn, args=(env,))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
@@ -47,7 +49,7 @@ def worker_fn_wrapper(fn):
|
|||||||
# and update the environment variables in the function
|
# and update the environment variables in the function
|
||||||
def wrapped_fn(env):
|
def wrapped_fn(env):
|
||||||
update_environment_variables(env)
|
update_environment_variables(env)
|
||||||
local_rank = os.environ['LOCAL_RANK']
|
local_rank = os.environ["LOCAL_RANK"]
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
@@ -58,17 +60,18 @@ def worker_fn_wrapper(fn):
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
tensor = torch.ones(16, 1024, 1024,
|
)
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
tensor = pynccl_comm.all_reduce(tensor)
|
tensor = pynccl_comm.all_reduce(tensor)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
|
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl():
|
def test_pynccl():
|
||||||
distributed_run(worker_fn, 2)
|
distributed_run(worker_fn, 2)
|
||||||
|
|
||||||
@@ -78,7 +81,7 @@ def multiple_allreduce_worker_fn():
|
|||||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||||
groups = [
|
groups = [
|
||||||
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
||||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo")
|
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
|
||||||
]
|
]
|
||||||
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
||||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||||
@@ -95,8 +98,9 @@ def multiple_allreduce_worker_fn():
|
|||||||
assert torch.all(tensor == 2).cpu().item()
|
assert torch.all(tensor == 2).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_multiple_allreduce():
|
def test_pynccl_multiple_allreduce():
|
||||||
# this tests pynccl for multiple tp groups, in a standalone way
|
# this tests pynccl for multiple tp groups, in a standalone way
|
||||||
# i.e. call `pynccl_comm.all_reduce` directly
|
# i.e. call `pynccl_comm.all_reduce` directly
|
||||||
@@ -121,8 +125,9 @@ def multiple_allreduce_with_vllm_worker_fn():
|
|||||||
assert torch.all(tensor == 2).cpu().item()
|
assert torch.all(tensor == 2).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_multiple_allreduce_with_vllm():
|
def test_pynccl_multiple_allreduce_with_vllm():
|
||||||
# this tests pynccl for multiple tp groups, together with vllm
|
# this tests pynccl for multiple tp groups, together with vllm
|
||||||
# i.e. call `tensor_model_parallel_all_reduce`
|
# i.e. call `tensor_model_parallel_all_reduce`
|
||||||
@@ -133,10 +138,11 @@ def test_pynccl_multiple_allreduce_with_vllm():
|
|||||||
def worker_fn_with_cudagraph():
|
def worker_fn_with_cudagraph():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
# run something in the default stream to initialize torch engine
|
# run something in the default stream to initialize torch engine
|
||||||
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
|
a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
with torch.cuda.graph(graph):
|
with torch.cuda.graph(graph):
|
||||||
a_out = pynccl_comm.all_reduce(a)
|
a_out = pynccl_comm.all_reduce(a)
|
||||||
@@ -148,84 +154,90 @@ def worker_fn_with_cudagraph():
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def all_gather_worker_fn():
|
def all_gather_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
num_elems = 1000
|
num_elems = 1000
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = (
|
||||||
device=device) + rank * num_elems
|
torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
|
||||||
result = torch.zeros(num_elems * world_size,
|
)
|
||||||
dtype=torch.float32,
|
result = torch.zeros(num_elems * world_size, dtype=torch.float32, device=device)
|
||||||
device=device)
|
|
||||||
|
|
||||||
expected = torch.cat([
|
expected = torch.cat(
|
||||||
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
[
|
||||||
for r in range(world_size)
|
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||||
]).to(device)
|
for r in range(world_size)
|
||||||
|
]
|
||||||
|
).to(device)
|
||||||
|
|
||||||
pynccl_comm.all_gather(result, tensor)
|
pynccl_comm.all_gather(result, tensor)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_all_gather():
|
def test_pynccl_all_gather():
|
||||||
distributed_run(all_gather_worker_fn, 2)
|
distributed_run(all_gather_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def all_gatherv_worker_fn():
|
def all_gatherv_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
assert world_size <= 8
|
assert world_size <= 8
|
||||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||||
num_elems = sizes[rank]
|
num_elems = sizes[rank]
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
|
||||||
device=device) + rank * 100
|
|
||||||
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
|
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
|
||||||
|
|
||||||
expected = torch.cat([
|
expected = torch.cat(
|
||||||
torch.arange(sizes[r], dtype=torch.float32) + r * 100
|
[
|
||||||
for r in range(world_size)
|
torch.arange(sizes[r], dtype=torch.float32) + r * 100
|
||||||
]).to(device)
|
for r in range(world_size)
|
||||||
|
]
|
||||||
|
).to(device)
|
||||||
|
|
||||||
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
|
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_all_gatherv():
|
def test_pynccl_all_gatherv():
|
||||||
distributed_run(all_gatherv_worker_fn, 2)
|
distributed_run(all_gatherv_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def reduce_scatter_worker_fn():
|
def reduce_scatter_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
num_elems = 1000
|
num_elems = 1000
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = (
|
||||||
device=device) + rank * num_elems
|
torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
|
||||||
assert (num_elems % world_size == 0)
|
)
|
||||||
result = torch.zeros(num_elems // world_size,
|
assert num_elems % world_size == 0
|
||||||
dtype=torch.float32,
|
result = torch.zeros(num_elems // world_size, dtype=torch.float32, device=device)
|
||||||
device=device)
|
|
||||||
|
|
||||||
# Calculate expected result for this rank's chunk
|
# Calculate expected result for this rank's chunk
|
||||||
scattered_size = num_elems // world_size
|
scattered_size = num_elems // world_size
|
||||||
@@ -233,34 +245,37 @@ def reduce_scatter_worker_fn():
|
|||||||
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||||
for r in range(world_size)
|
for r in range(world_size)
|
||||||
]
|
]
|
||||||
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
|
expected = sum(
|
||||||
for tensor in all_tensors).to(device)
|
tensor[rank * scattered_size : (rank + 1) * scattered_size]
|
||||||
|
for tensor in all_tensors
|
||||||
|
).to(device)
|
||||||
|
|
||||||
pynccl_comm.reduce_scatter(result, tensor)
|
pynccl_comm.reduce_scatter(result, tensor)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_reduce_scatter():
|
def test_pynccl_reduce_scatter():
|
||||||
distributed_run(reduce_scatter_worker_fn, 2)
|
distributed_run(reduce_scatter_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def reduce_scatterv_worker_fn():
|
def reduce_scatterv_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
|
|
||||||
rank = pynccl_comm.rank
|
rank = pynccl_comm.rank
|
||||||
world_size = pynccl_comm.world_size
|
world_size = pynccl_comm.world_size
|
||||||
device = f'cuda:{pynccl_comm.rank}'
|
device = f"cuda:{pynccl_comm.rank}"
|
||||||
|
|
||||||
assert world_size <= 8
|
assert world_size <= 8
|
||||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||||
num_elems = sum(sizes)
|
num_elems = sum(sizes)
|
||||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
|
||||||
device=device) + rank * 100
|
|
||||||
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
|
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
# Calculate expected result for this rank's chunk
|
# Calculate expected result for this rank's chunk
|
||||||
@@ -278,41 +293,41 @@ def reduce_scatterv_worker_fn():
|
|||||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_reduce_scatterv():
|
def test_pynccl_reduce_scatterv():
|
||||||
distributed_run(reduce_scatterv_worker_fn, 2)
|
distributed_run(reduce_scatterv_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_with_cudagraph():
|
def test_pynccl_with_cudagraph():
|
||||||
distributed_run(worker_fn_with_cudagraph, 2)
|
distributed_run(worker_fn_with_cudagraph, 2)
|
||||||
|
|
||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def send_recv_worker_fn():
|
def send_recv_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
if pynccl_comm.rank == 0:
|
if pynccl_comm.rank == 0:
|
||||||
tensor = torch.ones(16, 1024, 1024,
|
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
|
||||||
else:
|
else:
|
||||||
tensor = torch.empty(16, 1024, 1024,
|
tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
|
||||||
|
|
||||||
if pynccl_comm.rank == 0:
|
if pynccl_comm.rank == 0:
|
||||||
pynccl_comm.send(tensor,
|
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||||
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
|
||||||
else:
|
else:
|
||||||
pynccl_comm.recv(tensor,
|
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||||
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
assert torch.all(tensor == 1).cpu().item()
|
assert torch.all(tensor == 1).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_send_recv():
|
def test_pynccl_send_recv():
|
||||||
distributed_run(send_recv_worker_fn, 2)
|
distributed_run(send_recv_worker_fn, 2)
|
||||||
|
|
||||||
@@ -322,27 +337,20 @@ def multiple_send_recv_worker_fn():
|
|||||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||||
groups = [
|
groups = [
|
||||||
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
||||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo")
|
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
|
||||||
]
|
]
|
||||||
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
||||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||||
if torch.distributed.get_rank() == 0:
|
if torch.distributed.get_rank() == 0:
|
||||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||||
elif torch.distributed.get_rank() == 1:
|
elif torch.distributed.get_rank() == 1:
|
||||||
tensor = 2 * torch.ones(
|
tensor = 2 * torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||||
16, 1024, 1024, dtype=torch.float32, device=device)
|
|
||||||
else:
|
else:
|
||||||
tensor = torch.empty(16,
|
tensor = torch.empty(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||||
1024,
|
|
||||||
1024,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device)
|
|
||||||
if torch.distributed.get_rank() in [0, 1]:
|
if torch.distributed.get_rank() in [0, 1]:
|
||||||
pynccl_comm.send(tensor,
|
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||||
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
|
||||||
else:
|
else:
|
||||||
pynccl_comm.recv(tensor,
|
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||||
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
if torch.distributed.get_rank() in [0, 2]:
|
if torch.distributed.get_rank() in [0, 2]:
|
||||||
assert torch.all(tensor == 1).cpu().item()
|
assert torch.all(tensor == 1).cpu().item()
|
||||||
@@ -350,14 +358,16 @@ def multiple_send_recv_worker_fn():
|
|||||||
assert torch.all(tensor == 2).cpu().item()
|
assert torch.all(tensor == 2).cpu().item()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_multiple_send_recv():
|
def test_pynccl_multiple_send_recv():
|
||||||
distributed_run(multiple_send_recv_worker_fn, 4)
|
distributed_run(multiple_send_recv_worker_fn, 4)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||||
|
)
|
||||||
def test_pynccl_broadcast():
|
def test_pynccl_broadcast():
|
||||||
distributed_run(broadcast_worker_fn, 4)
|
distributed_run(broadcast_worker_fn, 4)
|
||||||
|
|
||||||
@@ -366,19 +376,17 @@ def test_pynccl_broadcast():
|
|||||||
def broadcast_worker_fn():
|
def broadcast_worker_fn():
|
||||||
# Test broadcast for every root rank.
|
# Test broadcast for every root rank.
|
||||||
# Essentially this is an all-gather operation.
|
# Essentially this is an all-gather operation.
|
||||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
pynccl_comm = PyNcclCommunicator(
|
||||||
device=get_world_group().device)
|
get_world_group().cpu_group, device=get_world_group().device
|
||||||
|
)
|
||||||
recv_tensors = [
|
recv_tensors = [
|
||||||
torch.empty(16,
|
torch.empty(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
|
||||||
1024,
|
|
||||||
1024,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=pynccl_comm.device)
|
|
||||||
for i in range(pynccl_comm.world_size)
|
for i in range(pynccl_comm.world_size)
|
||||||
]
|
]
|
||||||
recv_tensors[pynccl_comm.rank] = torch.ones(
|
recv_tensors[pynccl_comm.rank] = (
|
||||||
16, 1024, 1024, dtype=torch.float32,
|
torch.ones(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
|
||||||
device=pynccl_comm.device) * pynccl_comm.rank
|
* pynccl_comm.rank
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(pynccl_comm.world_size):
|
for i in range(pynccl_comm.world_size):
|
||||||
pynccl_comm.broadcast(recv_tensors[i], src=i)
|
pynccl_comm.broadcast(recv_tensors[i], src=i)
|
||||||
|
|||||||
@@ -8,20 +8,20 @@ import ray
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||||
tensor_model_parallel_all_reduce)
|
|
||||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import (ensure_model_parallel_initialized,
|
from ..utils import (
|
||||||
init_test_distributed_environment, multi_process_parallel)
|
ensure_model_parallel_initialized,
|
||||||
|
init_test_distributed_environment,
|
||||||
|
multi_process_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
random.seed(44)
|
random.seed(44)
|
||||||
# Size over 8MB is sufficient for custom quick allreduce.
|
# Size over 8MB is sufficient for custom quick allreduce.
|
||||||
test_sizes = [
|
test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
|
||||||
random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
|
|
||||||
]
|
|
||||||
for i, v in enumerate(test_sizes):
|
for i, v in enumerate(test_sizes):
|
||||||
test_sizes[i] -= v % 8
|
test_sizes[i] -= v % 8
|
||||||
|
|
||||||
@@ -38,8 +38,7 @@ def graph_quickreduce(
|
|||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||||
group = get_tp_group().device_group
|
group = get_tp_group().device_group
|
||||||
|
|
||||||
@@ -64,18 +63,15 @@ def graph_quickreduce(
|
|||||||
for sz in test_sizes:
|
for sz in test_sizes:
|
||||||
for dtype in [torch.float16, torch.bfloat16]:
|
for dtype in [torch.float16, torch.bfloat16]:
|
||||||
with graph_capture(device=device) as graph_capture_context:
|
with graph_capture(device=device) as graph_capture_context:
|
||||||
inp1 = torch.randint(1,
|
inp1 = torch.randint(
|
||||||
23, (sz, ),
|
1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
dtype=dtype,
|
)
|
||||||
device=torch.cuda.current_device())
|
inp2 = torch.randint(
|
||||||
inp2 = torch.randint(-23,
|
-23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||||
1, (sz, ),
|
)
|
||||||
dtype=dtype,
|
|
||||||
device=torch.cuda.current_device())
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(graph,
|
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||||
stream=graph_capture_context.stream):
|
|
||||||
for _ in range(num_communication):
|
for _ in range(num_communication):
|
||||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||||
dist.all_reduce(inp1, group=group)
|
dist.all_reduce(inp1, group=group)
|
||||||
@@ -99,39 +95,42 @@ def eager_quickreduce(
|
|||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
|
|
||||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||||
distributed_init_port)
|
|
||||||
|
|
||||||
# Size over 8MB is sufficient for custom quick allreduce.
|
# Size over 8MB is sufficient for custom quick allreduce.
|
||||||
sz = 16 * 1024 * 1024
|
sz = 16 * 1024 * 1024
|
||||||
fa = get_tp_group().device_communicator.qr_comm
|
fa = get_tp_group().device_communicator.qr_comm
|
||||||
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
|
inp = torch.tensor(
|
||||||
dtype=torch.float16,
|
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device
|
||||||
device=device)
|
)
|
||||||
out = fa.quick_all_reduce(inp)
|
out = fa.quick_all_reduce(inp)
|
||||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||||
|
|
||||||
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
|
inp = torch.tensor(
|
||||||
dtype=torch.bfloat16,
|
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
|
||||||
device=device)
|
)
|
||||||
out = fa.quick_all_reduce(inp)
|
out = fa.quick_all_reduce(inp)
|
||||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
@pytest.mark.skipif(
|
||||||
reason="only test quick allreduce for rocm")
|
not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
|
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
|
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
|
||||||
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
|
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
|
||||||
def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
def test_custom_quick_allreduce(
|
||||||
pipeline_parallel_size, test_target,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
quant_mode):
|
tp_size,
|
||||||
|
pipeline_parallel_size,
|
||||||
|
test_target,
|
||||||
|
quant_mode,
|
||||||
|
):
|
||||||
world_size = tp_size * pipeline_parallel_size
|
world_size = tp_size * pipeline_parallel_size
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
||||||
|
|
||||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
|
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||||
test_target)
|
|
||||||
|
|||||||
@@ -22,15 +22,13 @@ if __name__ == "__main__":
|
|||||||
dist.broadcast_object_list(recv, src=0)
|
dist.broadcast_object_list(recv, src=0)
|
||||||
ip, port = recv
|
ip, port = recv
|
||||||
|
|
||||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
|
||||||
dist.get_world_size())
|
|
||||||
|
|
||||||
for pg in [dist.group.WORLD, stateless_pg]:
|
for pg in [dist.group.WORLD, stateless_pg]:
|
||||||
test_result = all(in_the_same_node_as(pg, source_rank=0))
|
test_result = all(in_the_same_node_as(pg, source_rank=0))
|
||||||
|
|
||||||
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
|
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
|
||||||
assert test_result == expected, \
|
assert test_result == expected, f"Expected {expected}, got {test_result}"
|
||||||
f"Expected {expected}, got {test_result}"
|
|
||||||
if pg == dist.group.WORLD:
|
if pg == dist.group.WORLD:
|
||||||
print("Same node test passed! when using torch distributed!")
|
print("Same node test passed! when using torch distributed!")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
|||||||
all workers in a node other than the head node, which can cause the test
|
all workers in a node other than the head node, which can cause the test
|
||||||
to fail.
|
to fail.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -56,7 +57,8 @@ class SPTestSettings:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Length mismatch: distributed_backends "
|
f"Length mismatch: distributed_backends "
|
||||||
f"({len(self.distributed_backends)}) != "
|
f"({len(self.distributed_backends)}) != "
|
||||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
f"vllm_major_versions ({len(self.vllm_major_versions)})"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def detailed(
|
def detailed(
|
||||||
@@ -72,18 +74,22 @@ class SPTestSettings:
|
|||||||
for pp_multiplier in [1, 2]:
|
for pp_multiplier in [1, 2]:
|
||||||
for chunked_prefill_val in [False, True]:
|
for chunked_prefill_val in [False, True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
pp_size=pp_multiplier * pp_base,
|
tp_size=tp_base,
|
||||||
enable_fusion=False,
|
pp_size=pp_multiplier * pp_base,
|
||||||
eager_mode=eager_mode_val,
|
enable_fusion=False,
|
||||||
chunked_prefill=chunked_prefill_val))
|
eager_mode=eager_mode_val,
|
||||||
|
chunked_prefill=chunked_prefill_val,
|
||||||
|
)
|
||||||
|
)
|
||||||
return SPTestSettings(
|
return SPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
vllm_major_versions=["1", "1"],
|
vllm_major_versions=["1", "1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
test_options=SPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -100,18 +106,22 @@ class SPTestSettings:
|
|||||||
for pp_multiplier in [1, 2]:
|
for pp_multiplier in [1, 2]:
|
||||||
for chunked_prefill_val in [False, True]:
|
for chunked_prefill_val in [False, True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
pp_size=pp_multiplier * pp_base,
|
tp_size=tp_base,
|
||||||
enable_fusion=False,
|
pp_size=pp_multiplier * pp_base,
|
||||||
eager_mode=eager_mode_val,
|
enable_fusion=False,
|
||||||
chunked_prefill=chunked_prefill_val))
|
eager_mode=eager_mode_val,
|
||||||
|
chunked_prefill=chunked_prefill_val,
|
||||||
|
)
|
||||||
|
)
|
||||||
return SPTestSettings(
|
return SPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
vllm_major_versions=["1", "1"],
|
vllm_major_versions=["1", "1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
test_options=SPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -126,28 +136,39 @@ class SPTestSettings:
|
|||||||
parallel_setups = []
|
parallel_setups = []
|
||||||
for fusion_val in [False, True]:
|
for fusion_val in [False, True]:
|
||||||
parallel_setups.append(
|
parallel_setups.append(
|
||||||
ParallelSetup(tp_size=tp_base,
|
ParallelSetup(
|
||||||
pp_size=pp_base,
|
tp_size=tp_base,
|
||||||
enable_fusion=fusion_val,
|
pp_size=pp_base,
|
||||||
eager_mode=True,
|
enable_fusion=fusion_val,
|
||||||
chunked_prefill=False))
|
eager_mode=True,
|
||||||
|
chunked_prefill=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
return SPTestSettings(
|
return SPTestSettings(
|
||||||
parallel_setups=parallel_setups,
|
parallel_setups=parallel_setups,
|
||||||
distributed_backends=["mp", "ray"],
|
distributed_backends=["mp", "ray"],
|
||||||
vllm_major_versions=["1", "1"],
|
vllm_major_versions=["1", "1"],
|
||||||
runner=runner,
|
runner=runner,
|
||||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
test_options=SPTestOptions(
|
||||||
load_format=load_format),
|
multi_node_only=multi_node_only, load_format=load_format
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def iter_params(self, model_id: str):
|
def iter_params(self, model_id: str):
|
||||||
opts = self.test_options
|
opts = self.test_options
|
||||||
|
|
||||||
for parallel_setup in self.parallel_setups:
|
for parallel_setup in self.parallel_setups:
|
||||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
for backend, vllm_major_version in zip(
|
||||||
self.vllm_major_versions):
|
self.distributed_backends, self.vllm_major_versions
|
||||||
yield (model_id, parallel_setup, backend, vllm_major_version,
|
):
|
||||||
self.runner, opts)
|
yield (
|
||||||
|
model_id,
|
||||||
|
parallel_setup,
|
||||||
|
backend,
|
||||||
|
vllm_major_version,
|
||||||
|
self.runner,
|
||||||
|
opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _compare_sp(
|
def _compare_sp(
|
||||||
@@ -200,8 +221,10 @@ def _compare_sp(
|
|||||||
if num_gpus_available < tp_size * pp_size:
|
if num_gpus_available < tp_size * pp_size:
|
||||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
pytest.skip(
|
||||||
"multiprocessing distributed backend")
|
"Skipping multi-node pipeline parallel test for "
|
||||||
|
"multiprocessing distributed backend"
|
||||||
|
)
|
||||||
if multi_node_only and not VLLM_MULTI_NODE:
|
if multi_node_only and not VLLM_MULTI_NODE:
|
||||||
pytest.skip("Not in multi-node setting")
|
pytest.skip("Not in multi-node setting")
|
||||||
|
|
||||||
@@ -232,13 +255,13 @@ def _compare_sp(
|
|||||||
common_args.append("--skip-tokenizer-init")
|
common_args.append("--skip-tokenizer-init")
|
||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
'level': 3,
|
"level": 3,
|
||||||
'custom_ops': ["+rms_norm"],
|
"custom_ops": ["+rms_norm"],
|
||||||
'compile_sizes': [4, 8],
|
"compile_sizes": [4, 8],
|
||||||
'pass_config': {
|
"pass_config": {
|
||||||
'enable_sequence_parallelism': True,
|
"enable_sequence_parallelism": True,
|
||||||
'enable_fusion': enable_fusion,
|
"enable_fusion": enable_fusion,
|
||||||
'enable_noop': True,
|
"enable_noop": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,12 +293,9 @@ def _compare_sp(
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
compare_two_settings(model_id,
|
compare_two_settings(
|
||||||
tp_sp_args,
|
model_id, tp_sp_args, tp_args, tp_sp_env, tp_env, method=method
|
||||||
tp_args,
|
)
|
||||||
tp_sp_env,
|
|
||||||
tp_env,
|
|
||||||
method=method)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
testing_ray_compiled_graph = tp_sp_env is not None
|
testing_ray_compiled_graph = tp_sp_env is not None
|
||||||
if testing_ray_compiled_graph and vllm_major_version == "0":
|
if testing_ray_compiled_graph and vllm_major_version == "0":
|
||||||
@@ -301,10 +321,17 @@ SP_TEST_MODELS = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
|
(
|
||||||
"runner", "test_options"),
|
"model_id",
|
||||||
|
"parallel_setup",
|
||||||
|
"distributed_backend",
|
||||||
|
"vllm_major_version",
|
||||||
|
"runner",
|
||||||
|
"test_options",
|
||||||
|
),
|
||||||
[
|
[
|
||||||
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
|
params
|
||||||
|
for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
|
||||||
for params in settings.iter_params(model_id)
|
for params in settings.iter_params(model_id)
|
||||||
if model_id in SP_TEST_MODELS
|
if model_id in SP_TEST_MODELS
|
||||||
],
|
],
|
||||||
@@ -319,12 +346,14 @@ def test_tp_sp_generation(
|
|||||||
test_options: SPTestOptions,
|
test_options: SPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
_compare_sp(model_id,
|
_compare_sp(
|
||||||
parallel_setup,
|
model_id,
|
||||||
distributed_backend,
|
parallel_setup,
|
||||||
vllm_major_version,
|
distributed_backend,
|
||||||
runner,
|
vllm_major_version,
|
||||||
test_options,
|
runner,
|
||||||
num_gpus_available,
|
test_options,
|
||||||
method="generate",
|
num_gpus_available,
|
||||||
is_multimodal=False)
|
method="generate",
|
||||||
|
is_multimodal=False,
|
||||||
|
)
|
||||||
|
|||||||
@@ -26,13 +26,13 @@ def distributed_run(fn, world_size):
|
|||||||
processes = []
|
processes = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env = {}
|
env = {}
|
||||||
env['RANK'] = str(i)
|
env["RANK"] = str(i)
|
||||||
env['LOCAL_RANK'] = str(i)
|
env["LOCAL_RANK"] = str(i)
|
||||||
env['WORLD_SIZE'] = str(number_of_processes)
|
env["WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env['MASTER_ADDR'] = 'localhost'
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env['MASTER_PORT'] = '12345'
|
env["MASTER_PORT"] = "12345"
|
||||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
p = multiprocessing.Process(target=fn, args=(env,))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
@@ -57,25 +57,23 @@ def worker_fn_wrapper(fn):
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
port = get_open_port()
|
port = get_open_port()
|
||||||
ip = '127.0.0.1'
|
ip = "127.0.0.1"
|
||||||
dist.broadcast_object_list([ip, port], src=0)
|
dist.broadcast_object_list([ip, port], src=0)
|
||||||
else:
|
else:
|
||||||
recv = [None, None]
|
recv = [None, None]
|
||||||
dist.broadcast_object_list(recv, src=0)
|
dist.broadcast_object_list(recv, src=0)
|
||||||
ip, port = recv # type: ignore
|
ip, port = recv # type: ignore
|
||||||
|
|
||||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
|
||||||
dist.get_world_size())
|
|
||||||
|
|
||||||
for pg in [dist.group.WORLD, stateless_pg]:
|
for pg in [dist.group.WORLD, stateless_pg]:
|
||||||
|
|
||||||
writer_rank = 2
|
writer_rank = 2
|
||||||
broadcaster = MessageQueue.create_from_process_group(
|
broadcaster = MessageQueue.create_from_process_group(
|
||||||
pg, 40 * 1024, 2, writer_rank)
|
pg, 40 * 1024, 2, writer_rank
|
||||||
|
)
|
||||||
if rank == writer_rank:
|
if rank == writer_rank:
|
||||||
seed = random.randint(0, 1000)
|
seed = random.randint(0, 1000)
|
||||||
dist.broadcast_object_list([seed], writer_rank)
|
dist.broadcast_object_list([seed], writer_rank)
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ import traceback
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||||
SingleWriterShmRingBuffer)
|
SingleWriterShmRingBuffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
||||||
@@ -25,18 +26,21 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
"""Test opening an existing buffer"""
|
"""Test opening an existing buffer"""
|
||||||
# First create a buffer
|
# First create a buffer
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=self.buffer_size, create=True)
|
data_buffer_size=self.buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
# Then open it with another instance
|
# Then open it with another instance
|
||||||
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
|
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
|
||||||
self.assertFalse(reader_buffer.is_writer)
|
self.assertFalse(reader_buffer.is_writer)
|
||||||
self.assertEqual(reader_buffer.shared_memory.name,
|
self.assertEqual(
|
||||||
self.ring_buffer.shared_memory.name)
|
reader_buffer.shared_memory.name, self.ring_buffer.shared_memory.name
|
||||||
|
)
|
||||||
|
|
||||||
def test_buffer_access(self):
|
def test_buffer_access(self):
|
||||||
"""Test accessing allocated buffers"""
|
"""Test accessing allocated buffers"""
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=self.buffer_size, create=True)
|
data_buffer_size=self.buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
size = 100
|
size = 100
|
||||||
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
||||||
@@ -44,11 +48,11 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
# Write some test data
|
# Write some test data
|
||||||
test_data = b"Hello, World!" * 7 # 91 bytes
|
test_data = b"Hello, World!" * 7 # 91 bytes
|
||||||
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
||||||
data_buf[0:len(test_data)] = test_data
|
data_buf[0 : len(test_data)] = test_data
|
||||||
|
|
||||||
# Read it back
|
# Read it back
|
||||||
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
|
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
|
||||||
read_data = bytes(data_buf2[0:len(test_data)])
|
read_data = bytes(data_buf2[0 : len(test_data)])
|
||||||
read_id = metadata2[0]
|
read_id = metadata2[0]
|
||||||
|
|
||||||
self.assertEqual(read_data, test_data)
|
self.assertEqual(read_data, test_data)
|
||||||
@@ -58,7 +62,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
"""Test that MemoryError is raised when buffer is full"""
|
"""Test that MemoryError is raised when buffer is full"""
|
||||||
small_buffer_size = 200
|
small_buffer_size = 200
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=small_buffer_size, create=True)
|
data_buffer_size=small_buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
# Fill up the buffer
|
# Fill up the buffer
|
||||||
self.ring_buffer.allocate_buf(100)
|
self.ring_buffer.allocate_buf(100)
|
||||||
@@ -72,7 +77,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
"""Test allocation and freeing of buffers"""
|
"""Test allocation and freeing of buffers"""
|
||||||
small_buffer_size = 200
|
small_buffer_size = 200
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=small_buffer_size, create=True)
|
data_buffer_size=small_buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
size = 80
|
size = 80
|
||||||
# Write some data
|
# Write some data
|
||||||
@@ -81,7 +87,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
||||||
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
||||||
data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use
|
data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use
|
||||||
data_buf[4:len(test_data) + 4] = test_data
|
data_buf[4 : len(test_data) + 4] = test_data
|
||||||
print(self.ring_buffer.metadata)
|
print(self.ring_buffer.metadata)
|
||||||
freed_ids = self.ring_buffer.free_buf(lambda *args: True)
|
freed_ids = self.ring_buffer.free_buf(lambda *args: True)
|
||||||
print(f" Freed IDs: {freed_ids}")
|
print(f" Freed IDs: {freed_ids}")
|
||||||
@@ -90,7 +96,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
|||||||
def test_clear_buffer(self):
|
def test_clear_buffer(self):
|
||||||
"""Test clearing the buffer"""
|
"""Test clearing the buffer"""
|
||||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||||
data_buffer_size=self.buffer_size, create=True)
|
data_buffer_size=self.buffer_size, create=True
|
||||||
|
)
|
||||||
|
|
||||||
# Allocate some buffers
|
# Allocate some buffers
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
@@ -121,8 +128,7 @@ def main():
|
|||||||
# Manual demonstration
|
# Manual demonstration
|
||||||
try:
|
try:
|
||||||
print("Creating ring buffer...")
|
print("Creating ring buffer...")
|
||||||
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048,
|
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, create=True)
|
||||||
create=True)
|
|
||||||
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())
|
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())
|
||||||
|
|
||||||
print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
|
print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
|
||||||
@@ -140,7 +146,7 @@ def main():
|
|||||||
# Write some test data
|
# Write some test data
|
||||||
with writer_buffer.access_buf(address) as (data_buf, metadata):
|
with writer_buffer.access_buf(address) as (data_buf, metadata):
|
||||||
test_message = f"Test message {i}".encode()
|
test_message = f"Test message {i}".encode()
|
||||||
data_buf[0:len(test_message)] = test_message
|
data_buf[0 : len(test_message)] = test_message
|
||||||
|
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(f" Failed to allocate {size} bytes: {e}")
|
print(f" Failed to allocate {size} bytes: {e}")
|
||||||
|
|||||||
@@ -12,28 +12,33 @@ import torch
|
|||||||
|
|
||||||
# Assuming these are imported from your module
|
# Assuming these are imported from your module
|
||||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||||
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
|
MsgpackSerde,
|
||||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
|
SingleWriterShmObjectStorage,
|
||||||
MultiModalSharedField)
|
SingleWriterShmRingBuffer,
|
||||||
|
)
|
||||||
|
from vllm.multimodal.inputs import (
|
||||||
|
MultiModalFieldElem,
|
||||||
|
MultiModalKwargsItem,
|
||||||
|
MultiModalSharedField,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _dummy_elem(modality: str, key: str, size: int):
|
def _dummy_elem(modality: str, key: str, size: int):
|
||||||
return MultiModalFieldElem(
|
return MultiModalFieldElem(
|
||||||
modality=modality,
|
modality=modality,
|
||||||
key=key,
|
key=key,
|
||||||
data=torch.empty((size, ), dtype=torch.int8),
|
data=torch.empty((size,), dtype=torch.int8),
|
||||||
field=MultiModalSharedField(1),
|
field=MultiModalSharedField(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||||
return MultiModalKwargsItem.from_elems([
|
return MultiModalKwargsItem.from_elems(
|
||||||
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
|
[_dummy_elem(modality, key, size) for key, size in size_by_key.items()]
|
||||||
])
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Set up test fixtures before each test method."""
|
"""Set up test fixtures before each test method."""
|
||||||
ring_buffer = SingleWriterShmRingBuffer(
|
ring_buffer = SingleWriterShmRingBuffer(
|
||||||
@@ -208,8 +213,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError) as context:
|
||||||
self.storage.get(address, monotonic_id + 100)
|
self.storage.get(address, monotonic_id + 100)
|
||||||
|
|
||||||
self.assertIn("has been modified or is invalid", \
|
self.assertIn("has been modified or is invalid", str(context.exception))
|
||||||
str(context.exception))
|
|
||||||
|
|
||||||
def test_clear_storage(self):
|
def test_clear_storage(self):
|
||||||
"""Test clearing the storage."""
|
"""Test clearing the storage."""
|
||||||
@@ -234,8 +238,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
|||||||
# Reader process function
|
# Reader process function
|
||||||
def reader_process(process_id, storage_handle, items_to_read):
|
def reader_process(process_id, storage_handle, items_to_read):
|
||||||
"""Reader process that connects to existing shared memory and reads data."""
|
"""Reader process that connects to existing shared memory and reads data."""
|
||||||
reader_storage = SingleWriterShmObjectStorage.create_from_handle(
|
reader_storage = SingleWriterShmObjectStorage.create_from_handle(storage_handle)
|
||||||
storage_handle)
|
|
||||||
|
|
||||||
print(f"Reader {process_id} started")
|
print(f"Reader {process_id} started")
|
||||||
|
|
||||||
@@ -276,11 +279,7 @@ def run_multiprocess_example():
|
|||||||
|
|
||||||
# Test basic data types
|
# Test basic data types
|
||||||
test_data = [
|
test_data = [
|
||||||
("user_data", {
|
("user_data", {"name": "Alice", "age": 30, "scores": [95, 87, 92]}),
|
||||||
"name": "Alice",
|
|
||||||
"age": 30,
|
|
||||||
"scores": [95, 87, 92]
|
|
||||||
}),
|
|
||||||
("simple_string", "Hello, World!"),
|
("simple_string", "Hello, World!"),
|
||||||
("number", 42),
|
("number", 42),
|
||||||
("list_data", [1, 2, 3, "four", 5.0]),
|
("list_data", [1, 2, 3, "four", 5.0]),
|
||||||
@@ -301,8 +300,9 @@ def run_multiprocess_example():
|
|||||||
# initialize lock for reader processes
|
# initialize lock for reader processes
|
||||||
handle.reader_lock = Lock()
|
handle.reader_lock = Lock()
|
||||||
for i in range(storage.n_readers):
|
for i in range(storage.n_readers):
|
||||||
p = multiprocessing.Process(target=reader_process,
|
p = multiprocessing.Process(
|
||||||
args=(i, handle, stored_items))
|
target=reader_process, args=(i, handle, stored_items)
|
||||||
|
)
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ import vllm.envs as envs
|
|||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||||
CudaCommunicator)
|
from vllm.distributed.parallel_state import (
|
||||||
from vllm.distributed.parallel_state import (get_tp_group,
|
get_tp_group,
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel)
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -32,8 +33,7 @@ test_size_elements = 1024 * 1024
|
|||||||
|
|
||||||
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
||||||
monkeypatch = pytest.MonkeyPatch()
|
monkeypatch = pytest.MonkeyPatch()
|
||||||
config = VllmConfig(parallel_config=ParallelConfig(
|
config = VllmConfig(parallel_config=ParallelConfig(tensor_parallel_size=world_size))
|
||||||
tensor_parallel_size=world_size))
|
|
||||||
|
|
||||||
with monkeypatch.context() as m, set_current_vllm_config(config):
|
with monkeypatch.context() as m, set_current_vllm_config(config):
|
||||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||||
@@ -42,34 +42,34 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
|||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
update_environment_variables({
|
update_environment_variables(
|
||||||
'RANK': str(local_rank),
|
{
|
||||||
'LOCAL_RANK': str(local_rank),
|
"RANK": str(local_rank),
|
||||||
'WORLD_SIZE': str(world_size),
|
"LOCAL_RANK": str(local_rank),
|
||||||
'MASTER_ADDR': 'localhost',
|
"WORLD_SIZE": str(world_size),
|
||||||
'MASTER_PORT': '12345',
|
"MASTER_ADDR": "localhost",
|
||||||
})
|
"MASTER_PORT": "12345",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
cuda_communicator = typing.cast(CudaCommunicator,
|
cuda_communicator = typing.cast(
|
||||||
get_tp_group().device_communicator)
|
CudaCommunicator, get_tp_group().device_communicator
|
||||||
|
)
|
||||||
symm_mem_comm = cuda_communicator.symm_mem_comm
|
symm_mem_comm = cuda_communicator.symm_mem_comm
|
||||||
if symm_mem_comm is None or symm_mem_comm.disabled:
|
if symm_mem_comm is None or symm_mem_comm.disabled:
|
||||||
# can't use skip under multiprocessing
|
# can't use skip under multiprocessing
|
||||||
q.put("SymmMemCommunicator is not available or disabled.")
|
q.put("SymmMemCommunicator is not available or disabled.")
|
||||||
return
|
return
|
||||||
|
|
||||||
inp_direct_symm_mem = torch.randint(1,
|
inp_direct_symm_mem = torch.randint(
|
||||||
23, (test_size_elements, ),
|
1, 23, (test_size_elements,), dtype=dtype, device=device
|
||||||
dtype=dtype,
|
)
|
||||||
device=device)
|
|
||||||
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
|
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
|
||||||
# can't use skip under multiprocessing
|
# can't use skip under multiprocessing
|
||||||
q.put(
|
q.put("SymmMemCommunicator isn't used for this world and input size.")
|
||||||
"SymmMemCommunicator isn't used for this world and input size."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
|
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
|
||||||
@@ -78,42 +78,37 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
|||||||
|
|
||||||
group = get_tp_group().device_group
|
group = get_tp_group().device_group
|
||||||
dist.all_reduce(original_inp_direct_symm_mem, group=group)
|
dist.all_reduce(original_inp_direct_symm_mem, group=group)
|
||||||
torch.testing.assert_close(out_direct_symm_mem,
|
torch.testing.assert_close(
|
||||||
original_inp_direct_symm_mem,
|
out_direct_symm_mem, original_inp_direct_symm_mem, atol=2.5, rtol=0.1
|
||||||
atol=2.5,
|
)
|
||||||
rtol=0.1)
|
|
||||||
|
|
||||||
# Test tensor_model_parallel_all_reduce which should use symm_mem
|
# Test tensor_model_parallel_all_reduce which should use symm_mem
|
||||||
inp_tensor_parallel = torch.randint(-23,
|
inp_tensor_parallel = torch.randint(
|
||||||
1, (test_size_elements, ),
|
-23, 1, (test_size_elements,), dtype=dtype, device=device
|
||||||
dtype=dtype,
|
)
|
||||||
device=device)
|
|
||||||
original_inp_tensor_parallel = inp_tensor_parallel.clone()
|
original_inp_tensor_parallel = inp_tensor_parallel.clone()
|
||||||
out_tensor_parallel = tensor_model_parallel_all_reduce(
|
out_tensor_parallel = tensor_model_parallel_all_reduce(inp_tensor_parallel)
|
||||||
inp_tensor_parallel)
|
|
||||||
dist.all_reduce(original_inp_tensor_parallel, group=group)
|
dist.all_reduce(original_inp_tensor_parallel, group=group)
|
||||||
torch.testing.assert_close(out_tensor_parallel,
|
torch.testing.assert_close(
|
||||||
original_inp_tensor_parallel,
|
out_tensor_parallel, original_inp_tensor_parallel, atol=2.5, rtol=0.1
|
||||||
atol=2.5,
|
)
|
||||||
rtol=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda(),
|
||||||
reason="SymmMemAllreduce is only available for CUDA platforms.")
|
reason="SymmMemAllreduce is only available for CUDA platforms.",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
@pytest.mark.parametrize("tp_size", [2])
|
||||||
@pytest.mark.parametrize("pipeline_parallel_size", [1])
|
@pytest.mark.parametrize("pipeline_parallel_size", [1])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
reason="Only test on CUDA")
|
def test_symm_mem_allreduce(
|
||||||
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size
|
||||||
pipeline_parallel_size):
|
):
|
||||||
world_size = tp_size * pipeline_parallel_size
|
world_size = tp_size * pipeline_parallel_size
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
q = mp.get_context('spawn').Queue()
|
q = mp.get_context("spawn").Queue()
|
||||||
mp.spawn(symm_mem_allreduce_worker,
|
mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size)
|
||||||
args=(world_size, q),
|
|
||||||
nprocs=world_size)
|
|
||||||
try:
|
try:
|
||||||
val = q.get(timeout=1)
|
val = q.get(timeout=1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
@@ -126,18 +121,20 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
|||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda(),
|
||||||
reason="SymmMemAllreduce is only available for CUDA platforms.")
|
reason="SymmMemAllreduce is only available for CUDA platforms.",
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
)
|
||||||
reason="Only test on CUDA")
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
|
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
|
||||||
world_size = 4
|
world_size = 4
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
pytest.skip("Not enough GPUs to run the test.")
|
pytest.skip("Not enough GPUs to run the test.")
|
||||||
# Verify that the DataParallel runs without error
|
# Verify that the DataParallel runs without error
|
||||||
engine_args = EngineArgs(model="distilbert/distilgpt2",
|
engine_args = EngineArgs(
|
||||||
enforce_eager=True,
|
model="distilbert/distilgpt2",
|
||||||
enable_prefix_caching=True,
|
enforce_eager=True,
|
||||||
data_parallel_size=2,
|
enable_prefix_caching=True,
|
||||||
tensor_parallel_size=2,
|
data_parallel_size=2,
|
||||||
data_parallel_backend="mp")
|
tensor_parallel_size=2,
|
||||||
|
data_parallel_backend="mp",
|
||||||
|
)
|
||||||
LLMEngine.from_engine_args(engine_args)
|
LLMEngine.from_engine_args(engine_args)
|
||||||
|
|||||||
@@ -24,13 +24,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
|||||||
|
|
||||||
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
||||||
# to test if all ranks agree on the same kv cache configuration.
|
# to test if all ranks agree on the same kv cache configuration.
|
||||||
llm = LLM(model="facebook/opt-125m",
|
llm = LLM(
|
||||||
tensor_parallel_size=2,
|
model="facebook/opt-125m",
|
||||||
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
|
tensor_parallel_size=2,
|
||||||
distributed_executor_backend="external_launcher",
|
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
|
||||||
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
distributed_executor_backend="external_launcher",
|
||||||
swap_space=random.randint(1, 4),
|
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||||
seed=0)
|
swap_space=random.randint(1, 4),
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
@@ -48,15 +50,14 @@ def test_consistent_across_ranks(obj):
|
|||||||
assert container[0] == obj
|
assert container[0] == obj
|
||||||
|
|
||||||
|
|
||||||
test_consistent_across_ranks(
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
||||||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
||||||
test_consistent_across_ranks(
|
|
||||||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
|
||||||
|
|
||||||
# make sure we can access the model parameters from the calling process
|
# make sure we can access the model parameters from the calling process
|
||||||
# of the `LLM` instance.
|
# of the `LLM` instance.
|
||||||
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
|
params = list(
|
||||||
model.parameters())
|
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
|
||||||
|
)
|
||||||
test_consistent_across_ranks(len(params))
|
test_consistent_across_ranks(len(params))
|
||||||
|
|
||||||
# all ranks should have the same outputs
|
# all ranks should have the same outputs
|
||||||
@@ -65,5 +66,4 @@ for output in outputs:
|
|||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
test_consistent_across_ranks(prompt)
|
test_consistent_across_ranks(prompt)
|
||||||
test_consistent_across_ranks(generated_text)
|
test_consistent_across_ranks(generated_text)
|
||||||
print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
|
print(f"Rank {torch_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
f"Generated text: {generated_text!r}")
|
|
||||||
|
|||||||
@@ -24,23 +24,22 @@ dp_rank = int(os.getenv("DP_RANK", "0"))
|
|||||||
|
|
||||||
if dp_size > 1:
|
if dp_size > 1:
|
||||||
# distribute the prompts across the data parallel ranks
|
# distribute the prompts across the data parallel ranks
|
||||||
prompts = [
|
prompts = [prompt for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank]
|
||||||
prompt for idx, prompt in enumerate(prompts)
|
|
||||||
if idx % dp_size == dp_rank
|
|
||||||
]
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
||||||
# to test if all ranks agree on the same kv cache configuration.
|
# to test if all ranks agree on the same kv cache configuration.
|
||||||
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
|
llm = LLM(
|
||||||
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
|
model="microsoft/Phi-mini-MoE-instruct",
|
||||||
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
|
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
|
||||||
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
|
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
|
||||||
distributed_executor_backend="external_launcher",
|
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
|
||||||
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
distributed_executor_backend="external_launcher",
|
||||||
swap_space=random.randint(1, 4),
|
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||||
seed=0)
|
swap_space=random.randint(1, 4),
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
@@ -54,21 +53,18 @@ def test_consistent_across_ranks(obj):
|
|||||||
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
|
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
|
||||||
else:
|
else:
|
||||||
container = [None]
|
container = [None]
|
||||||
dist.broadcast_object_list(container,
|
dist.broadcast_object_list(container, src=group.ranks[0], group=cpu_group)
|
||||||
src=group.ranks[0],
|
|
||||||
group=cpu_group)
|
|
||||||
assert container[0] == obj
|
assert container[0] == obj
|
||||||
|
|
||||||
|
|
||||||
test_consistent_across_ranks(
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
||||||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
||||||
test_consistent_across_ranks(
|
|
||||||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
|
||||||
|
|
||||||
# make sure we can access the model parameters from the calling process
|
# make sure we can access the model parameters from the calling process
|
||||||
# of the `LLM` instance.
|
# of the `LLM` instance.
|
||||||
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
|
params = list(
|
||||||
model.parameters())
|
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
|
||||||
|
)
|
||||||
test_consistent_across_ranks(len(params))
|
test_consistent_across_ranks(len(params))
|
||||||
|
|
||||||
# all ranks should have the same outputs
|
# all ranks should have the same outputs
|
||||||
@@ -77,5 +73,4 @@ for output in outputs:
|
|||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
test_consistent_across_ranks(prompt)
|
test_consistent_across_ranks(prompt)
|
||||||
test_consistent_across_ranks(generated_text)
|
test_consistent_across_ranks(generated_text)
|
||||||
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
|
print(f"Rank {group_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
f"Generated text: {generated_text!r}")
|
|
||||||
|
|||||||
@@ -10,21 +10,22 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
from vllm.distributed.utils import StatelessProcessGroup
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.utils import (cuda_device_count_stateless, get_open_port,
|
from vllm.utils import (
|
||||||
update_environment_variables)
|
cuda_device_count_stateless,
|
||||||
|
get_open_port,
|
||||||
|
update_environment_variables,
|
||||||
|
)
|
||||||
|
|
||||||
from ..utils import multi_gpu_test
|
from ..utils import multi_gpu_test
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
class _CUDADeviceCountStatelessTestActor:
|
class _CUDADeviceCountStatelessTestActor:
|
||||||
|
|
||||||
def get_count(self):
|
def get_count(self):
|
||||||
return cuda_device_count_stateless()
|
return cuda_device_count_stateless()
|
||||||
|
|
||||||
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
||||||
update_environment_variables(
|
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
|
||||||
|
|
||||||
def get_cuda_visible_devices(self):
|
def get_cuda_visible_devices(self):
|
||||||
return envs.CUDA_VISIBLE_DEVICES
|
return envs.CUDA_VISIBLE_DEVICES
|
||||||
@@ -34,10 +35,9 @@ def test_cuda_device_count_stateless():
|
|||||||
"""Test that cuda_device_count_stateless changes return value if
|
"""Test that cuda_device_count_stateless changes return value if
|
||||||
CUDA_VISIBLE_DEVICES is changed."""
|
CUDA_VISIBLE_DEVICES is changed."""
|
||||||
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
|
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
|
||||||
num_gpus=2).remote()
|
num_gpus=2
|
||||||
assert len(
|
).remote()
|
||||||
sorted(ray.get(
|
assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2
|
||||||
actor.get_cuda_visible_devices.remote()).split(","))) == 2
|
|
||||||
assert ray.get(actor.get_count.remote()) == 2
|
assert ray.get(actor.get_count.remote()) == 2
|
||||||
ray.get(actor.set_cuda_visible_devices.remote("0"))
|
ray.get(actor.set_cuda_visible_devices.remote("0"))
|
||||||
assert ray.get(actor.get_count.remote()) == 1
|
assert ray.get(actor.get_count.remote()) == 1
|
||||||
@@ -46,15 +46,13 @@ def test_cuda_device_count_stateless():
|
|||||||
|
|
||||||
|
|
||||||
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg2 = StatelessProcessGroup.create(
|
||||||
port=port2,
|
host="127.0.0.1", port=port2, rank=rank, world_size=3
|
||||||
rank=rank,
|
)
|
||||||
world_size=3)
|
|
||||||
data = torch.tensor([rank])
|
data = torch.tensor([rank])
|
||||||
data = pg1.broadcast_obj(data, src=2)
|
data = pg1.broadcast_obj(data, src=2)
|
||||||
assert data.item() == 2
|
assert data.item() == 2
|
||||||
@@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg2 = StatelessProcessGroup.create(
|
||||||
port=port2,
|
host="127.0.0.1", port=port2, rank=rank, world_size=3
|
||||||
rank=rank,
|
)
|
||||||
world_size=3)
|
|
||||||
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||||
data = torch.tensor([rank]).cuda()
|
data = torch.tensor([rank]).cuda()
|
||||||
pynccl1.all_reduce(data)
|
pynccl1.all_reduce(data)
|
||||||
@@ -96,10 +92,9 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
|
|
||||||
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
if rank == 2:
|
if rank == 2:
|
||||||
pg1.broadcast_obj("secret", src=2)
|
pg1.broadcast_obj("secret", src=2)
|
||||||
else:
|
else:
|
||||||
@@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
|
|
||||||
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
pg1 = StatelessProcessGroup.create(
|
||||||
port=port1,
|
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||||
rank=rank,
|
)
|
||||||
world_size=WORLD_SIZE)
|
|
||||||
data = pg1.all_gather_obj(rank)
|
data = pg1.all_gather_obj(rank)
|
||||||
assert data == list(range(WORLD_SIZE))
|
assert data == list(range(WORLD_SIZE))
|
||||||
pg1.barrier()
|
pg1.barrier()
|
||||||
@@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
@pytest.mark.skip(reason="This test is flaky and prone to hang.")
|
@pytest.mark.skip(reason="This test is flaky and prone to hang.")
|
||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]
|
||||||
|
)
|
||||||
def test_stateless_process_group(worker):
|
def test_stateless_process_group(worker):
|
||||||
port1 = get_open_port()
|
port1 = get_open_port()
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
@@ -129,12 +124,14 @@ def test_stateless_process_group(worker):
|
|||||||
port2 = get_open_port()
|
port2 = get_open_port()
|
||||||
WORLD_SIZE = 4
|
WORLD_SIZE = 4
|
||||||
from multiprocessing import get_context
|
from multiprocessing import get_context
|
||||||
|
|
||||||
ctx = get_context("fork")
|
ctx = get_context("fork")
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(WORLD_SIZE):
|
for i in range(WORLD_SIZE):
|
||||||
rank = i
|
rank = i
|
||||||
processes.append(
|
processes.append(
|
||||||
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)))
|
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
|
||||||
|
)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.start()
|
p.start()
|
||||||
for p in processes:
|
for p in processes:
|
||||||
|
|||||||
@@ -10,22 +10,30 @@ from typing import Annotated, Literal, Optional, Union
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import CompilationConfig, config
|
from vllm.config import CompilationConfig, config
|
||||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
from vllm.engine.arg_utils import (
|
||||||
get_type, get_type_hints, is_not_builtin,
|
EngineArgs,
|
||||||
is_type, literal_to_kwargs, optional_type,
|
contains_type,
|
||||||
parse_type)
|
get_kwargs,
|
||||||
|
get_type,
|
||||||
|
get_type_hints,
|
||||||
|
is_not_builtin,
|
||||||
|
is_type,
|
||||||
|
literal_to_kwargs,
|
||||||
|
optional_type,
|
||||||
|
parse_type,
|
||||||
|
)
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type", "value", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
(int, "42", 42),
|
("type", "value", "expected"),
|
||||||
(float, "3.14", 3.14),
|
[
|
||||||
(str, "Hello World!", "Hello World!"),
|
(int, "42", 42),
|
||||||
(json.loads, '{"foo":1,"bar":2}', {
|
(float, "3.14", 3.14),
|
||||||
"foo": 1,
|
(str, "Hello World!", "Hello World!"),
|
||||||
"bar": 2
|
(json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}),
|
||||||
}),
|
],
|
||||||
])
|
)
|
||||||
def test_parse_type(type, value, expected):
|
def test_parse_type(type, value, expected):
|
||||||
parse_type_func = parse_type(type)
|
parse_type_func = parse_type(type)
|
||||||
assert parse_type_func(value) == expected
|
assert parse_type_func(value) == expected
|
||||||
@@ -37,50 +45,56 @@ def test_optional_type():
|
|||||||
assert optional_type_func("42") == 42
|
assert optional_type_func("42") == 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
(int, int, True),
|
("type_hint", "type", "expected"),
|
||||||
(int, float, False),
|
[
|
||||||
(list[int], list, True),
|
(int, int, True),
|
||||||
(list[int], tuple, False),
|
(int, float, False),
|
||||||
(Literal[0, 1], Literal, True),
|
(list[int], list, True),
|
||||||
])
|
(list[int], tuple, False),
|
||||||
|
(Literal[0, 1], Literal, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_is_type(type_hint, type, expected):
|
def test_is_type(type_hint, type, expected):
|
||||||
assert is_type(type_hint, type) == expected
|
assert is_type(type_hint, type) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
({float, int}, int, True),
|
("type_hints", "type", "expected"),
|
||||||
({int, tuple}, int, True),
|
[
|
||||||
({int, tuple[int]}, int, True),
|
({float, int}, int, True),
|
||||||
({int, tuple[int, ...]}, int, True),
|
({int, tuple}, int, True),
|
||||||
({int, tuple[int]}, float, False),
|
({int, tuple[int]}, int, True),
|
||||||
({int, tuple[int, ...]}, float, False),
|
({int, tuple[int, ...]}, int, True),
|
||||||
({str, Literal["x", "y"]}, Literal, True),
|
({int, tuple[int]}, float, False),
|
||||||
])
|
({int, tuple[int, ...]}, float, False),
|
||||||
|
({str, Literal["x", "y"]}, Literal, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_contains_type(type_hints, type, expected):
|
def test_contains_type(type_hints, type, expected):
|
||||||
assert contains_type(type_hints, type) == expected
|
assert contains_type(type_hints, type) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
({int, float}, int, int),
|
("type_hints", "type", "expected"),
|
||||||
({int, float}, str, None),
|
[
|
||||||
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
({int, float}, int, int),
|
||||||
])
|
({int, float}, str, None),
|
||||||
|
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_get_type(type_hints, type, expected):
|
def test_get_type(type_hints, type, expected):
|
||||||
assert get_type(type_hints, type) == expected
|
assert get_type(type_hints, type) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hints", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
({Literal[1, 2]}, {
|
("type_hints", "expected"),
|
||||||
"type": int,
|
[
|
||||||
"choices": [1, 2]
|
({Literal[1, 2]}, {"type": int, "choices": [1, 2]}),
|
||||||
}),
|
({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}),
|
||||||
({str, Literal["x", "y"]}, {
|
({Literal[1, "a"]}, Exception),
|
||||||
"type": str,
|
],
|
||||||
"metavar": ["x", "y"]
|
)
|
||||||
}),
|
|
||||||
({Literal[1, "a"]}, Exception),
|
|
||||||
])
|
|
||||||
def test_literal_to_kwargs(type_hints, expected):
|
def test_literal_to_kwargs(type_hints, expected):
|
||||||
context = nullcontext()
|
context = nullcontext()
|
||||||
if expected is Exception:
|
if expected is Exception:
|
||||||
@@ -123,22 +137,27 @@ class DummyConfig:
|
|||||||
"""Nested config"""
|
"""Nested config"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
@pytest.mark.parametrize(
|
||||||
(int, False),
|
("type_hint", "expected"),
|
||||||
(DummyConfig, True),
|
[
|
||||||
])
|
(int, False),
|
||||||
|
(DummyConfig, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_is_not_builtin(type_hint, expected):
|
def test_is_not_builtin(type_hint, expected):
|
||||||
assert is_not_builtin(type_hint) == expected
|
assert is_not_builtin(type_hint) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("type_hint", "expected"), [
|
("type_hint", "expected"),
|
||||||
|
[
|
||||||
(Annotated[int, "annotation"], {int}),
|
(Annotated[int, "annotation"], {int}),
|
||||||
(Optional[int], {int, type(None)}),
|
(Optional[int], {int, type(None)}),
|
||||||
(Annotated[Optional[int], "annotation"], {int, type(None)}),
|
(Annotated[Optional[int], "annotation"], {int, type(None)}),
|
||||||
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
|
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
|
||||||
],
|
],
|
||||||
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
|
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"],
|
||||||
|
)
|
||||||
def test_get_type_hints(type_hint, expected):
|
def test_get_type_hints(type_hint, expected):
|
||||||
assert get_type_hints(type_hint) == expected
|
assert get_type_hints(type_hint) == expected
|
||||||
|
|
||||||
@@ -178,24 +197,16 @@ def test_get_kwargs():
|
|||||||
("arg", "expected"),
|
("arg", "expected"),
|
||||||
[
|
[
|
||||||
(None, dict()),
|
(None, dict()),
|
||||||
('{"video": {"num_frames": 123} }', {
|
('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}),
|
||||||
"video": {
|
|
||||||
"num_frames": 123
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
(
|
(
|
||||||
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
|
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
|
||||||
{
|
{
|
||||||
"video": {
|
"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"},
|
||||||
"num_frames": 123,
|
"image": {"foo": "bar"},
|
||||||
"fps": 1.0,
|
},
|
||||||
"foo": "bar"
|
),
|
||||||
},
|
],
|
||||||
"image": {
|
)
|
||||||
"foo": "bar"
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
def test_media_io_kwargs_parser(arg, expected):
|
def test_media_io_kwargs_parser(arg, expected):
|
||||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||||
if arg is None:
|
if arg is None:
|
||||||
@@ -230,24 +241,32 @@ def test_compilation_config():
|
|||||||
assert args.compilation_config.level == 3
|
assert args.compilation_config.level == 3
|
||||||
|
|
||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args(
|
||||||
"-O",
|
[
|
||||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
"-O",
|
||||||
'"use_inductor": false}',
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||||
])
|
'"use_inductor": false}',
|
||||||
assert (args.compilation_config.level == 3 and
|
]
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
)
|
||||||
and not args.compilation_config.use_inductor)
|
assert (
|
||||||
|
args.compilation_config.level == 3
|
||||||
|
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
|
and not args.compilation_config.use_inductor
|
||||||
|
)
|
||||||
|
|
||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args(
|
||||||
"--compilation-config="
|
[
|
||||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
"--compilation-config="
|
||||||
'"use_inductor": true}',
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||||
])
|
'"use_inductor": true}',
|
||||||
assert (args.compilation_config.level == 3 and
|
]
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
)
|
||||||
and args.compilation_config.use_inductor)
|
assert (
|
||||||
|
args.compilation_config.level == 3
|
||||||
|
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
|
and args.compilation_config.use_inductor
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_default():
|
def test_prefix_cache_default():
|
||||||
@@ -255,8 +274,7 @@ def test_prefix_cache_default():
|
|||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
|
|
||||||
engine_args = EngineArgs.from_cli_args(args=args)
|
engine_args = EngineArgs.from_cli_args(args=args)
|
||||||
assert (not engine_args.enable_prefix_caching
|
assert not engine_args.enable_prefix_caching, "prefix caching defaults to off."
|
||||||
), "prefix caching defaults to off."
|
|
||||||
|
|
||||||
# with flag to turn it on.
|
# with flag to turn it on.
|
||||||
args = parser.parse_args(["--enable-prefix-caching"])
|
args = parser.parse_args(["--enable-prefix-caching"])
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import pytest
|
|||||||
|
|
||||||
from ..conftest import IMAGE_ASSETS
|
from ..conftest import IMAGE_ASSETS
|
||||||
|
|
||||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
|
||||||
"stop_sign":
|
{
|
||||||
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
"stop_sign": "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
||||||
"cherry_blossom":
|
"cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:",
|
||||||
"USER: <image>\nWhat is the season?\nASSISTANT:",
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
models = ["llava-hf/llava-1.5-7b-hf"]
|
models = ["llava-hf/llava-1.5-7b-hf"]
|
||||||
|
|
||||||
@@ -19,8 +19,7 @@ models = ["llava-hf/llava-1.5-7b-hf"]
|
|||||||
def test_context_length_too_short(vllm_runner, image_assets, model):
|
def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||||
images = [asset.pil_image for asset in image_assets]
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
with pytest.raises(ValueError,
|
with pytest.raises(ValueError, match="longer than the maximum model length"):
|
||||||
match="longer than the maximum model length"):
|
|
||||||
vllm_model = vllm_runner(
|
vllm_model = vllm_runner(
|
||||||
model,
|
model,
|
||||||
max_model_len=128, # LLaVA has a feature size of 576
|
max_model_len=128, # LLaVA has a feature size of 576
|
||||||
@@ -29,6 +28,6 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with vllm_model:
|
with vllm_model:
|
||||||
vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]],
|
vllm_model.generate_greedy(
|
||||||
max_tokens=1,
|
[HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]]
|
||||||
images=[images[0]])
|
)
|
||||||
|
|||||||
@@ -26,8 +26,10 @@ def sample_token_ids():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_regex():
|
def sample_regex():
|
||||||
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
return (
|
||||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||||
|
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -35,40 +37,27 @@ def sample_json_schema():
|
|||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"name": {
|
"name": {"type": "string"},
|
||||||
"type": "string"
|
"age": {"type": "integer"},
|
||||||
},
|
|
||||||
"age": {
|
|
||||||
"type": "integer"
|
|
||||||
},
|
|
||||||
"skills": {
|
"skills": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {"type": "string", "maxLength": 10},
|
||||||
"type": "string",
|
"minItems": 3,
|
||||||
"maxLength": 10
|
|
||||||
},
|
|
||||||
"minItems": 3
|
|
||||||
},
|
},
|
||||||
"work_history": {
|
"work_history": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"company": {
|
"company": {"type": "string"},
|
||||||
"type": "string"
|
"duration": {"type": "number"},
|
||||||
},
|
"position": {"type": "string"},
|
||||||
"duration": {
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
"position": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["company", "position"]
|
"required": ["company", "position"],
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"required": ["name", "age", "skills", "work_history"]
|
"required": ["name", "age", "skills", "work_history"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -80,65 +69,53 @@ def sample_complex_json_schema():
|
|||||||
"score": {
|
"score": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"minimum": 0,
|
"minimum": 0,
|
||||||
"maximum": 100 # Numeric range
|
"maximum": 100, # Numeric range
|
||||||
},
|
},
|
||||||
"grade": {
|
"grade": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"pattern": "^[A-D]$" # Regex pattern
|
"pattern": "^[A-D]$", # Regex pattern
|
||||||
},
|
},
|
||||||
"email": {
|
"email": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$",
|
||||||
},
|
},
|
||||||
"tags": {
|
"tags": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"pattern":
|
"pattern": "^[a-z]{1,10}$", # Combining length and pattern restrictions
|
||||||
"^[a-z]{1,10}$" # Combining length and pattern restrictions
|
},
|
||||||
}
|
},
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["score", "grade", "email", "tags"]
|
"required": ["score", "grade", "email", "tags"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_definition_json_schema():
|
def sample_definition_json_schema():
|
||||||
return {
|
return {
|
||||||
'$defs': {
|
"$defs": {
|
||||||
'Step': {
|
"Step": {
|
||||||
'properties': {
|
"properties": {
|
||||||
'explanation': {
|
"explanation": {"title": "Explanation", "type": "string"},
|
||||||
'title': 'Explanation',
|
"output": {"title": "Output", "type": "string"},
|
||||||
'type': 'string'
|
|
||||||
},
|
|
||||||
'output': {
|
|
||||||
'title': 'Output',
|
|
||||||
'type': 'string'
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
'required': ['explanation', 'output'],
|
"required": ["explanation", "output"],
|
||||||
'title': 'Step',
|
"title": "Step",
|
||||||
'type': 'object'
|
"type": "object",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'properties': {
|
"properties": {
|
||||||
'steps': {
|
"steps": {
|
||||||
'items': {
|
"items": {"$ref": "#/$defs/Step"},
|
||||||
'$ref': '#/$defs/Step'
|
"title": "Steps",
|
||||||
},
|
"type": "array",
|
||||||
'title': 'Steps',
|
|
||||||
'type': 'array'
|
|
||||||
},
|
},
|
||||||
'final_answer': {
|
"final_answer": {"title": "Final Answer", "type": "string"},
|
||||||
'title': 'Final Answer',
|
|
||||||
'type': 'string'
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
'required': ['steps', 'final_answer'],
|
"required": ["steps", "final_answer"],
|
||||||
'title': 'MathReasoning',
|
"title": "MathReasoning",
|
||||||
'type': 'object'
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -149,64 +126,71 @@ def sample_enum_json_schema():
|
|||||||
"properties": {
|
"properties": {
|
||||||
"status": {
|
"status": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["active", "inactive",
|
"enum": ["active", "inactive", "pending"], # Literal values using enum
|
||||||
"pending"] # Literal values using enum
|
|
||||||
},
|
},
|
||||||
"priority": {
|
"priority": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["low", "medium", "high", "critical"]
|
"enum": ["low", "medium", "high", "critical"],
|
||||||
},
|
},
|
||||||
"category": {
|
"category": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"type": {
|
"type": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["bug", "feature", "improvement"]
|
"enum": ["bug", "feature", "improvement"],
|
||||||
},
|
},
|
||||||
"severity": {
|
"severity": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"enum": [1, 2, 3, 4,
|
"enum": [1, 2, 3, 4, 5], # Enum can also contain numbers
|
||||||
5] # Enum can also contain numbers
|
},
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["type", "severity"]
|
"required": ["type", "severity"],
|
||||||
},
|
},
|
||||||
"flags": {
|
"flags": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["urgent", "blocked", "needs_review", "approved"]
|
"enum": ["urgent", "blocked", "needs_review", "approved"],
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"required": ["status", "priority", "category", "flags"]
|
"required": ["status", "priority", "category", "flags"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_structured_outputs_choices():
|
def sample_structured_outputs_choices():
|
||||||
return [
|
return [
|
||||||
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
|
"Python",
|
||||||
"Ruby", "Swift", "Kotlin"
|
"Java",
|
||||||
|
"JavaScript",
|
||||||
|
"C++",
|
||||||
|
"C#",
|
||||||
|
"PHP",
|
||||||
|
"TypeScript",
|
||||||
|
"Ruby",
|
||||||
|
"Swift",
|
||||||
|
"Kotlin",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_sql_statements():
|
def sample_sql_statements():
|
||||||
return ("""
|
return """
|
||||||
start: select_statement
|
start: select_statement
|
||||||
select_statement: "SELECT" column "from" table "where" condition
|
select_statement: "SELECT" column "from" table "where" condition
|
||||||
column: "col_1" | "col_2"
|
column: "col_1" | "col_2"
|
||||||
table: "table_1" | "table_2"
|
table: "table_1" | "table_2"
|
||||||
condition: column "=" number
|
condition: column "=" number
|
||||||
number: "1" | "2"
|
number: "1" | "2"
|
||||||
""")
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def zephyr_lora_files():
|
def zephyr_lora_files():
|
||||||
"""Download zephyr LoRA files once per test session."""
|
"""Download zephyr LoRA files once per test session."""
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
|
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
|
||||||
|
|
||||||
|
|
||||||
@@ -214,5 +198,5 @@ def zephyr_lora_files():
|
|||||||
def opt125_lora_files() -> str:
|
def opt125_lora_files() -> str:
|
||||||
"""Download opt-125m LoRA files once per test session."""
|
"""Download opt-125m LoRA files once per test session."""
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
return snapshot_download(
|
|
||||||
repo_id="peft-internal-testing/opt-125m-dummy-lora")
|
return snapshot_download(repo_id="peft-internal-testing/opt-125m-dummy-lora")
|
||||||
|
|||||||
@@ -48,20 +48,23 @@ def run_test(model_name, more_args=None):
|
|||||||
|
|
||||||
measured_value = results["results"][TASK][FILTER]
|
measured_value = results["results"][TASK][FILTER]
|
||||||
assert model_name in EXPECTED_VALUES, (
|
assert model_name in EXPECTED_VALUES, (
|
||||||
f"Cannot find the expected value for the model {model_name=}")
|
f"Cannot find the expected value for the model {model_name=}"
|
||||||
|
)
|
||||||
expected_value = EXPECTED_VALUES[model_name]
|
expected_value = EXPECTED_VALUES[model_name]
|
||||||
assert (measured_value - RTOL < expected_value
|
assert (
|
||||||
and measured_value + RTOL > expected_value
|
measured_value - RTOL < expected_value
|
||||||
), f"Expected: {expected_value} | Measured: {measured_value}"
|
and measured_value + RTOL > expected_value
|
||||||
|
), f"Expected: {expected_value} | Measured: {measured_value}"
|
||||||
|
|
||||||
|
|
||||||
# TODO: [AlexM] Fix it with new CI/CD tests
|
# TODO: [AlexM] Fix it with new CI/CD tests
|
||||||
TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
|
TPU_TP_TEST_STR = "" # "tensor_parallel_size=4"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
and not current_platform.is_tpu(),
|
not current_platform.is_cuda() and not current_platform.is_tpu(),
|
||||||
reason="V1 is currently only supported on CUDA and TPU")
|
reason="V1 is currently only supported on CUDA and TPU",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("model", MODEL_NAMES)
|
@pytest.mark.parametrize("model", MODEL_NAMES)
|
||||||
def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Run with the V1 Engine."""
|
"""Run with the V1 Engine."""
|
||||||
@@ -82,12 +85,14 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
|||||||
run_test(model, more_args)
|
run_test(model, more_args)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
and not current_platform.is_tpu(),
|
not current_platform.is_cuda() and not current_platform.is_tpu(),
|
||||||
reason="V1 is currently only supported on CUDA and TPU")
|
reason="V1 is currently only supported on CUDA and TPU",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
|
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
|
||||||
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
|
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
|
||||||
model, monkeypatch: pytest.MonkeyPatch):
|
model, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
"""Run with the V1 Engine."""
|
"""Run with the V1 Engine."""
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
|
|||||||
@@ -14,9 +14,7 @@ from ..openai.test_vision import TEST_IMAGE_ASSETS
|
|||||||
def text_llm():
|
def text_llm():
|
||||||
# pytest caches the fixture so we use weakref.proxy to
|
# pytest caches the fixture so we use weakref.proxy to
|
||||||
# enable garbage collection
|
# enable garbage collection
|
||||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, seed=0)
|
||||||
enforce_eager=True,
|
|
||||||
seed=0)
|
|
||||||
|
|
||||||
yield weakref.proxy(llm)
|
yield weakref.proxy(llm)
|
||||||
|
|
||||||
@@ -28,14 +26,8 @@ def text_llm():
|
|||||||
def test_chat(text_llm):
|
def test_chat(text_llm):
|
||||||
prompt1 = "Explain the concept of entropy."
|
prompt1 = "Explain the concept of entropy."
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": prompt1},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt1
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
outputs = text_llm.chat(messages)
|
outputs = text_llm.chat(messages)
|
||||||
assert len(outputs) == 1
|
assert len(outputs) == 1
|
||||||
@@ -46,25 +38,13 @@ def test_multi_chat(text_llm):
|
|||||||
prompt2 = "Explain what among us is."
|
prompt2 = "Explain what among us is."
|
||||||
|
|
||||||
conversation1 = [
|
conversation1 = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": prompt1},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt1
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
conversation2 = [
|
conversation2 = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": prompt2},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt2
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
messages = [conversation1, conversation2]
|
messages = [conversation1, conversation2]
|
||||||
@@ -94,26 +74,22 @@ def vision_llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("image_urls",
|
@pytest.mark.parametrize(
|
||||||
[[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]],
|
"image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True
|
||||||
indirect=True)
|
)
|
||||||
def test_chat_multi_image(vision_llm, image_urls: list[str]):
|
def test_chat_multi_image(vision_llm, image_urls: list[str]):
|
||||||
messages = [{
|
messages = [
|
||||||
"role":
|
{
|
||||||
"user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
*({
|
*(
|
||||||
"type": "image_url",
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
"image_url": {
|
for image_url in image_urls
|
||||||
"url": image_url
|
),
|
||||||
}
|
{"type": "text", "text": "What's in this image?"},
|
||||||
} for image_url in image_urls),
|
],
|
||||||
{
|
}
|
||||||
"type": "text",
|
]
|
||||||
"text": "What's in this image?"
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}]
|
|
||||||
outputs = vision_llm.chat(messages)
|
outputs = vision_llm.chat(messages)
|
||||||
assert len(outputs) >= 0
|
assert len(outputs) >= 0
|
||||||
|
|
||||||
@@ -124,14 +100,8 @@ def test_llm_chat_tokenization_no_double_bos(text_llm):
|
|||||||
Check we get a single BOS token for llama chat.
|
Check we get a single BOS token for llama chat.
|
||||||
"""
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": "Hello!"},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Hello!"
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
outputs = text_llm.chat(messages)
|
outputs = text_llm.chat(messages)
|
||||||
assert len(outputs) == 1
|
assert len(outputs) == 1
|
||||||
@@ -167,14 +137,8 @@ def thinking_llm():
|
|||||||
@pytest.mark.parametrize("enable_thinking", [True, False])
|
@pytest.mark.parametrize("enable_thinking", [True, False])
|
||||||
def test_chat_extra_kwargs(thinking_llm, enable_thinking):
|
def test_chat_extra_kwargs(thinking_llm, enable_thinking):
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
"role": "system",
|
{"role": "user", "content": "What is 1+1?"},
|
||||||
"content": "You are a helpful assistant"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "What is 1+1?"
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
outputs = thinking_llm.chat(
|
outputs = thinking_llm.chat(
|
||||||
|
|||||||
@@ -23,9 +23,11 @@ def test_collective_rpc(tp_size, backend, monkeypatch):
|
|||||||
return self.rank
|
return self.rank
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
llm = LLM(
|
||||||
enforce_eager=True,
|
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||||
load_format="dummy",
|
enforce_eager=True,
|
||||||
tensor_parallel_size=tp_size,
|
load_format="dummy",
|
||||||
distributed_executor_backend=backend)
|
tensor_parallel_size=tp_size,
|
||||||
|
distributed_executor_backend=backend,
|
||||||
|
)
|
||||||
assert llm.collective_rpc(echo_rank) == list(range(tp_size))
|
assert llm.collective_rpc(echo_rank) == list(range(tp_size))
|
||||||
|
|||||||
@@ -29,11 +29,13 @@ TOKEN_IDS = [
|
|||||||
def llm():
|
def llm():
|
||||||
# pytest caches the fixture so we use weakref.proxy to
|
# pytest caches the fixture so we use weakref.proxy to
|
||||||
# enable garbage collection
|
# enable garbage collection
|
||||||
llm = LLM(model=MODEL_NAME,
|
llm = LLM(
|
||||||
max_num_batched_tokens=4096,
|
model=MODEL_NAME,
|
||||||
tensor_parallel_size=1,
|
max_num_batched_tokens=4096,
|
||||||
gpu_memory_utilization=0.10,
|
tensor_parallel_size=1,
|
||||||
enforce_eager=True)
|
gpu_memory_utilization=0.10,
|
||||||
|
enforce_eager=True,
|
||||||
|
)
|
||||||
|
|
||||||
yield weakref.proxy(llm)
|
yield weakref.proxy(llm)
|
||||||
|
|
||||||
@@ -81,7 +83,8 @@ def test_max_model_len():
|
|||||||
outputs = llm.generate(PROMPTS, sampling_params)
|
outputs = llm.generate(PROMPTS, sampling_params)
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
num_total_tokens = len(output.prompt_token_ids) + len(
|
num_total_tokens = len(output.prompt_token_ids) + len(
|
||||||
output.outputs[0].token_ids)
|
output.outputs[0].token_ids
|
||||||
|
)
|
||||||
# Total tokens must not exceed max_model_len + 1 (the last token can be
|
# Total tokens must not exceed max_model_len + 1 (the last token can be
|
||||||
# generated with the context length equal to the max model length)
|
# generated with the context length equal to the max model length)
|
||||||
# It can be less if generation finishes due to other reasons (e.g., EOS)
|
# It can be less if generation finishes due to other reasons (e.g., EOS)
|
||||||
|
|||||||
@@ -16,9 +16,8 @@ def test_gpu_memory_utilization():
|
|||||||
# makes sure gpu_memory_utilization is per-instance limit,
|
# makes sure gpu_memory_utilization is per-instance limit,
|
||||||
# not a global limit
|
# not a global limit
|
||||||
llms = [
|
llms = [
|
||||||
LLM(model="facebook/opt-125m",
|
LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, enforce_eager=True)
|
||||||
gpu_memory_utilization=0.3,
|
for i in range(3)
|
||||||
enforce_eager=True) for i in range(3)
|
|
||||||
]
|
]
|
||||||
for llm in llms:
|
for llm in llms:
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ from vllm import LLM
|
|||||||
|
|
||||||
def test_empty_prompt():
|
def test_empty_prompt():
|
||||||
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
||||||
with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
|
with pytest.raises(ValueError, match="decoder prompt cannot be empty"):
|
||||||
llm.generate([""])
|
llm.generate([""])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_v1
|
@pytest.mark.skip_v1
|
||||||
def test_out_of_vocab_token():
|
def test_out_of_vocab_token():
|
||||||
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
||||||
with pytest.raises(ValueError, match='out of vocabulary'):
|
with pytest.raises(ValueError, match="out of vocabulary"):
|
||||||
llm.generate({"prompt_token_ids": [999999]})
|
llm.generate({"prompt_token_ids": [999999]})
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Tests for HF_HUB_OFFLINE mode"""
|
"""Tests for HF_HUB_OFFLINE mode"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import importlib
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
@@ -91,12 +92,11 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch):
|
|||||||
|
|
||||||
|
|
||||||
def _re_import_modules():
|
def _re_import_modules():
|
||||||
hf_hub_module_names = [
|
hf_hub_module_names = [k for k in sys.modules if k.startswith("huggingface_hub")]
|
||||||
k for k in sys.modules if k.startswith("huggingface_hub")
|
|
||||||
]
|
|
||||||
transformers_module_names = [
|
transformers_module_names = [
|
||||||
k for k in sys.modules if k.startswith("transformers")
|
k
|
||||||
and not k.startswith("transformers_modules")
|
for k in sys.modules
|
||||||
|
if k.startswith("transformers") and not k.startswith("transformers_modules")
|
||||||
]
|
]
|
||||||
|
|
||||||
reload_exception = None
|
reload_exception = None
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ from vllm.assets.audio import AudioAsset
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mary_had_lamb():
|
def mary_had_lamb():
|
||||||
path = AudioAsset('mary_had_lamb').get_local_path()
|
path = AudioAsset("mary_had_lamb").get_local_path()
|
||||||
with open(str(path), "rb") as f:
|
with open(str(path), "rb") as f:
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def winning_call():
|
def winning_call():
|
||||||
path = AudioAsset('winning_call').get_local_path()
|
path = AudioAsset("winning_call").get_local_path()
|
||||||
with open(str(path), "rb") as f:
|
with open(str(path), "rb") as f:
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
@@ -22,6 +22,6 @@ def winning_call():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def foscolo():
|
def foscolo():
|
||||||
# Test translation it->en
|
# Test translation it->en
|
||||||
path = AudioAsset('azacinto_foscolo').get_local_path()
|
path = AudioAsset("azacinto_foscolo").get_local_path()
|
||||||
with open(str(path), "rb") as f:
|
with open(str(path), "rb") as f:
|
||||||
yield f
|
yield f
|
||||||
|
|||||||
@@ -44,14 +44,15 @@ def run_test(more_args):
|
|||||||
print(f"Running with: {args}")
|
print(f"Running with: {args}")
|
||||||
|
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
MODEL_NAME, args,
|
MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS
|
||||||
max_wait_seconds=MAX_WAIT_SECONDS) as remote_server:
|
) as remote_server:
|
||||||
url = f"{remote_server.url_for('v1')}/completions"
|
url = f"{remote_server.url_for('v1')}/completions"
|
||||||
|
|
||||||
model_args = (
|
model_args = (
|
||||||
f"model={MODEL_NAME},"
|
f"model={MODEL_NAME},"
|
||||||
f"base_url={url},"
|
f"base_url={url},"
|
||||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
|
||||||
|
)
|
||||||
|
|
||||||
results = lm_eval.simple_evaluate(
|
results = lm_eval.simple_evaluate(
|
||||||
model="local-completions",
|
model="local-completions",
|
||||||
@@ -60,15 +61,18 @@ def run_test(more_args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
measured_value = results["results"][TASK][FILTER]
|
measured_value = results["results"][TASK][FILTER]
|
||||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
assert (
|
||||||
and measured_value + RTOL > EXPECTED_VALUE
|
measured_value - RTOL < EXPECTED_VALUE
|
||||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
and measured_value + RTOL > EXPECTED_VALUE
|
||||||
|
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
@pytest.mark.skipif(
|
||||||
and not current_platform.is_tpu()
|
not current_platform.is_cuda()
|
||||||
and not current_platform.is_xpu(),
|
and not current_platform.is_tpu()
|
||||||
reason="V1 currently only supported on CUDA, XPU and TPU")
|
and not current_platform.is_xpu(),
|
||||||
|
reason="V1 currently only supported on CUDA, XPU and TPU",
|
||||||
|
)
|
||||||
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
|
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Run with the V1 Engine."""
|
"""Run with the V1 Engine."""
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ a baseline.
|
|||||||
This simulates real work usage of the API and makes sure that the frontend and
|
This simulates real work usage of the API and makes sure that the frontend and
|
||||||
AsyncLLMEngine are working correctly.
|
AsyncLLMEngine are working correctly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
@@ -45,7 +46,8 @@ async def transcribe_audio(client, tokenizer, y, sr):
|
|||||||
# NOTE there's no streaming in transcriptions, can't measure ttft
|
# NOTE there's no streaming in transcriptions, can't measure ttft
|
||||||
latency = end_time - start_time
|
latency = end_time - start_time
|
||||||
num_output_tokens = len(
|
num_output_tokens = len(
|
||||||
tokenizer(transcription.text, add_special_tokens=False).input_ids)
|
tokenizer(transcription.text, add_special_tokens=False).input_ids
|
||||||
|
)
|
||||||
return latency, num_output_tokens, transcription.text
|
return latency, num_output_tokens, transcription.text
|
||||||
|
|
||||||
|
|
||||||
@@ -73,8 +75,8 @@ async def process_dataset(model, client, data, concurrent_request):
|
|||||||
for sample in data:
|
for sample in data:
|
||||||
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
bound_transcribe(sem, client, tokenizer, (audio, sr),
|
bound_transcribe(sem, client, tokenizer, (audio, sr), sample["text"])
|
||||||
sample["text"]))
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
@@ -98,34 +100,35 @@ def print_performance_metrics(results, total_time):
|
|||||||
|
|
||||||
|
|
||||||
def add_duration(sample):
|
def add_duration(sample):
|
||||||
y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
|
y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||||
sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
|
sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
|
def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs):
|
||||||
## Load and filter the dataset
|
## Load and filter the dataset
|
||||||
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
|
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
|
||||||
if 'duration_ms' not in dataset[0]:
|
if "duration_ms" not in dataset[0]:
|
||||||
# compute duration to filter
|
# compute duration to filter
|
||||||
dataset = dataset.map(add_duration)
|
dataset = dataset.map(add_duration)
|
||||||
|
|
||||||
# Whisper max supported duration
|
# Whisper max supported duration
|
||||||
dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
|
dataset = dataset.filter(lambda example: example["duration_ms"] < 30000)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def run_evaluation(model: str,
|
def run_evaluation(
|
||||||
client,
|
model: str,
|
||||||
dataset,
|
client,
|
||||||
max_concurrent_reqs: int,
|
dataset,
|
||||||
n_examples: int = -1,
|
max_concurrent_reqs: int,
|
||||||
print_metrics: bool = True):
|
n_examples: int = -1,
|
||||||
|
print_metrics: bool = True,
|
||||||
|
):
|
||||||
if n_examples > 0:
|
if n_examples > 0:
|
||||||
dataset = dataset.select(range(n_examples))
|
dataset = dataset.select(range(n_examples))
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
results = asyncio.run(
|
results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs))
|
||||||
process_dataset(model, client, dataset, max_concurrent_reqs))
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
total_time = end - start
|
total_time = end - start
|
||||||
print(f"Total Test Time: {total_time:.4f} seconds")
|
print(f"Total Test Time: {total_time:.4f} seconds")
|
||||||
@@ -135,8 +138,7 @@ def run_evaluation(model: str,
|
|||||||
predictions = [res[2] for res in results]
|
predictions = [res[2] for res in results]
|
||||||
references = [res[3] for res in results]
|
references = [res[3] for res in results]
|
||||||
wer = load("wer")
|
wer = load("wer")
|
||||||
wer_score = 100 * wer.compute(references=references,
|
wer_score = 100 * wer.compute(references=references, predictions=predictions)
|
||||||
predictions=predictions)
|
|
||||||
print("WER:", wer_score)
|
print("WER:", wer_score)
|
||||||
return wer_score
|
return wer_score
|
||||||
|
|
||||||
@@ -145,26 +147,25 @@ def run_evaluation(model: str,
|
|||||||
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
|
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
|
||||||
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
|
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
|
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
|
||||||
|
)
|
||||||
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
||||||
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
||||||
@pytest.mark.parametrize("expected_wer", [12.744980])
|
@pytest.mark.parametrize("expected_wer", [12.744980])
|
||||||
def test_wer_correctness(model_name,
|
def test_wer_correctness(
|
||||||
dataset_repo,
|
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
|
||||||
expected_wer,
|
):
|
||||||
n_examples=-1,
|
|
||||||
max_concurrent_request=None):
|
|
||||||
# TODO refactor to use `ASRDataset`
|
# TODO refactor to use `ASRDataset`
|
||||||
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
|
with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
|
||||||
dataset = load_hf_dataset(dataset_repo)
|
dataset = load_hf_dataset(dataset_repo)
|
||||||
|
|
||||||
if not max_concurrent_request:
|
if not max_concurrent_request:
|
||||||
# No max concurrency
|
# No max concurrency
|
||||||
max_concurrent_request = n_examples if n_examples > 0\
|
max_concurrent_request = n_examples if n_examples > 0 else len(dataset)
|
||||||
else len(dataset)
|
|
||||||
|
|
||||||
client = remote_server.get_async_client()
|
client = remote_server.get_async_client()
|
||||||
wer = run_evaluation(model_name, client, dataset,
|
wer = run_evaluation(
|
||||||
max_concurrent_request, n_examples)
|
model_name, client, dataset, max_concurrent_request, n_examples
|
||||||
|
)
|
||||||
if expected_wer:
|
if expected_wer:
|
||||||
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
|
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
|
||||||
|
|||||||
@@ -44,15 +44,11 @@ async def client(server):
|
|||||||
ids=["completion", "chat"],
|
ids=["completion", "chat"],
|
||||||
argnames=["create_func_gen", "content_body"],
|
argnames=["create_func_gen", "content_body"],
|
||||||
argvalues=[
|
argvalues=[
|
||||||
(lambda x: x.completions.create, {
|
(lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}),
|
||||||
"prompt": " ".join(['A'] * 10_000)
|
(
|
||||||
}),
|
lambda x: x.chat.completions.create,
|
||||||
(lambda x: x.chat.completions.create, {
|
{"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]},
|
||||||
"messages": [{
|
),
|
||||||
"role": "user",
|
|
||||||
"content": " ".join(['A'] * 10_000)
|
|
||||||
}]
|
|
||||||
}),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_with_and_without_truncate(
|
async def test_with_and_without_truncate(
|
||||||
@@ -65,15 +61,15 @@ async def test_with_and_without_truncate(
|
|||||||
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
||||||
|
|
||||||
num_requests = 10
|
num_requests = 10
|
||||||
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
|
truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * (
|
||||||
(num_requests - num_requests // 2))
|
num_requests - num_requests // 2
|
||||||
|
)
|
||||||
random.shuffle(truncate_prompt_tokens)
|
random.shuffle(truncate_prompt_tokens)
|
||||||
|
|
||||||
bodies = [{
|
bodies = [
|
||||||
**body, "extra_body": {
|
{**body, "extra_body": {"truncate_prompt_tokens": t}}
|
||||||
'truncate_prompt_tokens': t
|
for t in truncate_prompt_tokens
|
||||||
}
|
]
|
||||||
} for t in truncate_prompt_tokens]
|
|
||||||
|
|
||||||
async def get_status_code(**kwargs):
|
async def get_status_code(**kwargs):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -56,24 +56,18 @@ def base64_encoded_audio() -> dict[str, str]:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
async def test_single_chat_session_audio(
|
||||||
model_name: str, audio_url: str):
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||||
messages = [{
|
):
|
||||||
"role":
|
messages = [
|
||||||
"user",
|
{
|
||||||
"content": [
|
"role": "user",
|
||||||
{
|
"content": [
|
||||||
"type": "audio_url",
|
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||||
"audio_url": {
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
"url": audio_url
|
],
|
||||||
}
|
}
|
||||||
},
|
]
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}]
|
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -82,13 +76,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
|||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||||
|
)
|
||||||
|
|
||||||
message = choice.message
|
message = choice.message
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
@@ -110,56 +106,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI,
|
async def test_error_on_invalid_audio_url_type(
|
||||||
model_name: str,
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||||
audio_url: str):
|
):
|
||||||
messages = [{
|
messages = [
|
||||||
"role":
|
{
|
||||||
"user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{"type": "audio_url", "audio_url": audio_url},
|
||||||
"type": "audio_url",
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
"audio_url": audio_url
|
],
|
||||||
},
|
}
|
||||||
{
|
]
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}]
|
|
||||||
|
|
||||||
# audio_url should be a dict {"url": "some url"}, not directly a string
|
# audio_url should be a dict {"url": "some url"}, not directly a string
|
||||||
with pytest.raises(openai.BadRequestError):
|
with pytest.raises(openai.BadRequestError):
|
||||||
_ = await client.chat.completions.create(model=model_name,
|
_ = await client.chat.completions.create(
|
||||||
messages=messages,
|
model=model_name,
|
||||||
max_completion_tokens=10,
|
messages=messages,
|
||||||
temperature=0.0)
|
max_completion_tokens=10,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_single_chat_session_audio_base64encoded(
|
async def test_single_chat_session_audio_base64encoded(
|
||||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI,
|
||||||
base64_encoded_audio: dict[str, str]):
|
model_name: str,
|
||||||
|
audio_url: str,
|
||||||
messages = [{
|
base64_encoded_audio: dict[str, str],
|
||||||
"role":
|
):
|
||||||
"user",
|
messages = [
|
||||||
"content": [
|
{
|
||||||
{
|
"role": "user",
|
||||||
"type": "audio_url",
|
"content": [
|
||||||
"audio_url": {
|
{
|
||||||
"url":
|
"type": "audio_url",
|
||||||
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
|
"audio_url": {
|
||||||
}
|
"url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
|
||||||
},
|
},
|
||||||
{
|
},
|
||||||
"type": "text",
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
"text": "What's happening in this audio?"
|
],
|
||||||
},
|
}
|
||||||
],
|
]
|
||||||
}]
|
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -168,13 +160,15 @@ async def test_single_chat_session_audio_base64encoded(
|
|||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||||
|
)
|
||||||
|
|
||||||
message = choice.message
|
message = choice.message
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
@@ -198,25 +192,26 @@ async def test_single_chat_session_audio_base64encoded(
|
|||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||||
async def test_single_chat_session_input_audio(
|
async def test_single_chat_session_input_audio(
|
||||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI,
|
||||||
base64_encoded_audio: dict[str, str]):
|
model_name: str,
|
||||||
messages = [{
|
audio_url: str,
|
||||||
"role":
|
base64_encoded_audio: dict[str, str],
|
||||||
"user",
|
):
|
||||||
"content": [
|
messages = [
|
||||||
{
|
{
|
||||||
"type": "input_audio",
|
"role": "user",
|
||||||
"input_audio": {
|
"content": [
|
||||||
"data": base64_encoded_audio[audio_url],
|
{
|
||||||
"format": "wav"
|
"type": "input_audio",
|
||||||
}
|
"input_audio": {
|
||||||
},
|
"data": base64_encoded_audio[audio_url],
|
||||||
{
|
"format": "wav",
|
||||||
"type": "text",
|
},
|
||||||
"text": "What's happening in this audio?"
|
},
|
||||||
},
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
],
|
],
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -224,13 +219,15 @@ async def test_single_chat_session_input_audio(
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
max_completion_tokens=10,
|
max_completion_tokens=10,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
top_logprobs=5)
|
top_logprobs=5,
|
||||||
|
)
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
|
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||||
|
)
|
||||||
|
|
||||||
message = choice.message
|
message = choice.message
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
@@ -252,24 +249,18 @@ async def test_single_chat_session_input_audio(
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
async def test_chat_streaming_audio(
|
||||||
model_name: str, audio_url: str):
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||||
messages = [{
|
):
|
||||||
"role":
|
messages = [
|
||||||
"user",
|
{
|
||||||
"content": [
|
"role": "user",
|
||||||
{
|
"content": [
|
||||||
"type": "audio_url",
|
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||||
"audio_url": {
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
"url": audio_url
|
],
|
||||||
}
|
}
|
||||||
},
|
]
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's happening in this audio?"
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}]
|
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -309,27 +300,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
async def test_chat_streaming_input_audio(
|
||||||
model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI,
|
||||||
base64_encoded_audio: dict[str,
|
model_name: str,
|
||||||
str]):
|
audio_url: str,
|
||||||
messages = [{
|
base64_encoded_audio: dict[str, str],
|
||||||
"role":
|
):
|
||||||
"user",
|
messages = [
|
||||||
"content": [
|
{
|
||||||
{
|
"role": "user",
|
||||||
"type": "input_audio",
|
"content": [
|
||||||
"input_audio": {
|
{
|
||||||
"data": base64_encoded_audio[audio_url],
|
"type": "input_audio",
|
||||||
"format": "wav"
|
"input_audio": {
|
||||||
}
|
"data": base64_encoded_audio[audio_url],
|
||||||
},
|
"format": "wav",
|
||||||
{
|
},
|
||||||
"type": "text",
|
},
|
||||||
"text": "What's happening in this audio?"
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
},
|
],
|
||||||
],
|
}
|
||||||
}]
|
]
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
@@ -369,26 +360,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]])
|
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]
|
||||||
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
|
)
|
||||||
audio_urls: list[str]):
|
async def test_multi_audio_input(
|
||||||
|
client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str]
|
||||||
messages = [{
|
):
|
||||||
"role":
|
messages = [
|
||||||
"user",
|
{
|
||||||
"content": [
|
"role": "user",
|
||||||
*({
|
"content": [
|
||||||
"type": "audio_url",
|
*(
|
||||||
"audio_url": {
|
{"type": "audio_url", "audio_url": {"url": audio_url}}
|
||||||
"url": audio_url
|
for audio_url in audio_urls
|
||||||
}
|
),
|
||||||
} for audio_url in audio_urls),
|
{"type": "text", "text": "What's happening in this audio?"},
|
||||||
{
|
],
|
||||||
"type": "text",
|
}
|
||||||
"text": "What's happening in this audio?"
|
]
|
||||||
},
|
|
||||||
],
|
|
||||||
}]
|
|
||||||
|
|
||||||
if len(audio_urls) > MAXIMUM_AUDIOS:
|
if len(audio_urls) > MAXIMUM_AUDIOS:
|
||||||
with pytest.raises(openai.BadRequestError): # test multi-audio input
|
with pytest.raises(openai.BadRequestError): # test multi-audio input
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ from ...utils import RemoteOpenAIServer
|
|||||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope="module")
|
||||||
def server_args(request: pytest.FixtureRequest) -> list[str]:
|
def server_args(request: pytest.FixtureRequest) -> list[str]:
|
||||||
""" Provide extra arguments to the server via indirect parametrization
|
"""Provide extra arguments to the server via indirect parametrization
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
@@ -80,8 +80,10 @@ async def client(server):
|
|||||||
"server_args",
|
"server_args",
|
||||||
[
|
[
|
||||||
pytest.param([], id="default-frontend-multiprocessing"),
|
pytest.param([], id="default-frontend-multiprocessing"),
|
||||||
pytest.param(["--disable-frontend-multiprocessing"],
|
pytest.param(
|
||||||
id="disable-frontend-multiprocessing")
|
["--disable-frontend-multiprocessing"],
|
||||||
|
id="disable-frontend-multiprocessing",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer):
|
|||||||
"server_args",
|
"server_args",
|
||||||
[
|
[
|
||||||
pytest.param([], id="default-frontend-multiprocessing"),
|
pytest.param([], id="default-frontend-multiprocessing"),
|
||||||
pytest.param(["--disable-frontend-multiprocessing"],
|
pytest.param(
|
||||||
id="disable-frontend-multiprocessing")
|
["--disable-frontend-multiprocessing"],
|
||||||
|
id="disable-frontend-multiprocessing",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"server_args",
|
"server_args",
|
||||||
[
|
[
|
||||||
pytest.param(["--max-model-len", "10100"],
|
pytest.param(
|
||||||
id="default-frontend-multiprocessing"),
|
["--max-model-len", "10100"], id="default-frontend-multiprocessing"
|
||||||
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
|
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
|
||||||
id="disable-frontend-multiprocessing")
|
id="disable-frontend-multiprocessing",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
|||||||
# Request about 2 million tokens
|
# Request about 2 million tokens
|
||||||
for _ in range(200):
|
for _ in range(200):
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
client.chat.completions.create(messages=chat_input,
|
client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
messages=chat_input,
|
||||||
max_tokens=10000,
|
model=MODEL_NAME,
|
||||||
extra_body={"min_tokens": 10000}))
|
max_tokens=10000,
|
||||||
|
extra_body={"min_tokens": 10000},
|
||||||
|
)
|
||||||
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
done, pending = await asyncio.wait(tasks,
|
done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||||
return_when=asyncio.ALL_COMPLETED)
|
|
||||||
|
|
||||||
# Make sure all requests were sent to the server and timed out
|
# Make sure all requests were sent to the server and timed out
|
||||||
# (We don't want to hide other errors like 400s that would invalidate this
|
# (We don't want to hide other errors like 400s that would invalidate this
|
||||||
@@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
|||||||
# If the server had not cancelled all the other requests, then it would not
|
# If the server had not cancelled all the other requests, then it would not
|
||||||
# be able to respond to this one within the timeout
|
# be able to respond to this one within the timeout
|
||||||
client = server.get_async_client(timeout=5)
|
client = server.get_async_client(timeout=5)
|
||||||
response = await client.chat.completions.create(messages=chat_input,
|
response = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
messages=chat_input, model=MODEL_NAME, max_tokens=10
|
||||||
max_tokens=10)
|
)
|
||||||
|
|
||||||
assert len(response.choices) == 1
|
assert len(response.choices) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
||||||
|
|
||||||
chat_input = [{"role": "user", "content": "Write a long story"}]
|
chat_input = [{"role": "user", "content": "Write a long story"}]
|
||||||
client = server.get_async_client()
|
client = server.get_async_client()
|
||||||
|
|
||||||
@@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
|||||||
messages=chat_input,
|
messages=chat_input,
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
max_tokens=10000,
|
max_tokens=10000,
|
||||||
extra_headers={
|
extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
"Content-Type": "application/x-www-form-urlencoded"
|
)
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"server_args",
|
"server_args",
|
||||||
[
|
[pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")],
|
||||||
pytest.param(["--enable-server-load-tracking"],
|
|
||||||
id="enable-server-load-tracking")
|
|
||||||
],
|
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -202,7 +205,8 @@ async def test_server_load(server: RemoteOpenAIServer):
|
|||||||
|
|
||||||
# Start the completion request in a background thread.
|
# Start the completion request in a background thread.
|
||||||
completion_future = asyncio.create_task(
|
completion_future = asyncio.create_task(
|
||||||
asyncio.to_thread(make_long_completion_request))
|
asyncio.to_thread(make_long_completion_request)
|
||||||
|
)
|
||||||
|
|
||||||
# Give a short delay to ensure the request has started.
|
# Give a short delay to ensure the request has started.
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user