Compare commits

7 Commits

| SHA1 |
|---|
| 90eb3f43ca |
| e67b4f2c2a |
| d6770d1f23 |
| b9cecc2635 |
| 898285c9bf |
| a62de9ecfd |
| 4042d192f5 |
```diff
@@ -3,31 +3,15 @@
 Installation
 ============
 
-vLLM is a Python library that also contains some C++ and CUDA code.
-This additional code requires compilation on the user's machine.
+vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
 
 Requirements
 ------------
 
 * OS: Linux
-* Python: 3.8 or higher
-* CUDA: 11.0 -- 11.8
+* Python: 3.8 -- 3.11
 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
 
-.. note::
-    As of now, vLLM does not support CUDA 12.
-    If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
-
-.. tip::
-    If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
-
-    .. code-block:: console
-
-        $ # Pull the Docker image with CUDA 11.8.
-        $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
-
-    Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
-
 Install with pip
 ----------------
 
@@ -40,7 +24,7 @@ You can install vLLM using pip:
     $ conda activate myenv
 
    $ # Install vLLM.
     $ pip install vllm  # This may take 5-10 minutes.
 
 
 .. _build_from_source:
@@ -55,3 +39,11 @@ You can also build and install vLLM from source:
     $ git clone https://github.com/vllm-project/vllm.git
     $ cd vllm
     $ pip install -e .  # This may take 5-10 minutes.
+
+.. tip::
+    If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
+
+    .. code-block:: console
+
+        $ # Pull the Docker image with CUDA 11.8.
+        $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
```
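Taken together, this hunk set tracks the switch to wheels that bundle pre-compiled CUDA 11.8 binaries: the CUDA version range and the compile-on-install caveats disappear from the requirements, Python 3.11 becomes the documented ceiling, and the Docker tip moves under build-from-source, where compilation still happens. A minimal sketch (not part of the diff) for checking a local environment against the pre-built binaries; it assumes torch is installed and a GPU is visible:

```python
# Sketch: sanity-check the local stack against the pre-built CUDA 11.8
# wheels. Assumes torch is installed and at least one GPU is visible.
import torch

print(torch.version.cuda)                   # CUDA version torch was built with
print(torch.cuda.get_device_capability(0))  # vLLM needs (7, 0) or higher
```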
|||||||
49
setup.py
49
setup.py
@@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
|||||||
|
|
||||||
if CUDA_HOME is None:
|
if CUDA_HOME is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||||
|
|
||||||
|
|
||||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||||
@@ -54,7 +54,8 @@ if nvcc_cuda_version < Version("11.0"):
|
|||||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
|
||||||
|
)
|
||||||
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||||
@@ -65,7 +66,8 @@ if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
|||||||
compute_capabilities.add(80)
|
compute_capabilities.add(80)
|
||||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
|
||||||
|
)
|
||||||
|
|
||||||
# If no GPU is available, add all supported compute capabilities.
|
# If no GPU is available, add all supported compute capabilities.
|
||||||
if not compute_capabilities:
|
if not compute_capabilities:
|
||||||
```diff
@@ -78,7 +80,9 @@ if not compute_capabilities:
 
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+    NVCC_FLAGS += [
+        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
+    ]
 
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
```
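Only formatting changes in this hunk, but it is the center of the build matrix: one `-gencode` pair per target compute capability. A standalone sketch of the expansion (the capability set here is illustrative):

```python
# Sketch: how the loop above expands target compute capabilities into
# NVCC -gencode flags. The capability set is illustrative.
compute_capabilities = {70, 80, 90}
NVCC_FLAGS = []
for capability in sorted(compute_capabilities):
    NVCC_FLAGS += [
        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
    ]
print(NVCC_FLAGS)
# ['-gencode', 'arch=compute_70,code=sm_70',
#  '-gencode', 'arch=compute_80,code=sm_80',
#  '-gencode', 'arch=compute_90,code=sm_90']
```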
```diff
@@ -91,7 +95,10 @@ ext_modules = []
 cache_extension = CUDAExtension(
     name="vllm.cache_ops",
     sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(cache_extension)
 
@@ -99,7 +106,10 @@ ext_modules.append(cache_extension)
 attention_extension = CUDAExtension(
     name="vllm.attention_ops",
     sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(attention_extension)
 
@@ -107,7 +117,10 @@ ext_modules.append(attention_extension)
 positional_encoding_extension = CUDAExtension(
     name="vllm.pos_encoding_ops",
     sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(positional_encoding_extension)
 
@@ -115,7 +128,10 @@ ext_modules.append(positional_encoding_extension)
 layernorm_extension = CUDAExtension(
     name="vllm.layernorm_ops",
     sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(layernorm_extension)
 
@@ -123,7 +139,10 @@ ext_modules.append(layernorm_extension)
 activation_extension = CUDAExtension(
     name="vllm.activation_ops",
     sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(activation_extension)
```
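All five hunks apply the same mechanical reflow of `extra_compile_args` into a multi-line dict; behavior is unchanged. Because each extension is named under the `vllm` package, a successful build makes the compiled ops importable as ordinary submodules, e.g. (a sketch, assuming the build completed):

```python
# Sketch: the compiled extensions land inside the installed package, so
# vLLM's Python code imports them like regular modules.
from vllm import attention_ops, cache_ops  # built from the csrc/ sources
```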
```diff
@@ -138,8 +157,8 @@ def find_version(filepath: str):
     Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
     """
     with open(filepath) as fp:
-        version_match = re.search(
-            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
+        version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+                                  fp.read(), re.M)
         if version_match:
             return version_match.group(1)
         raise RuntimeError("Unable to find version string.")
```
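The regex itself is unchanged; the call is just re-wrapped to the new line-length style. A quick standalone check (the file contents are hypothetical) that the pattern extracts the version assignment:

```python
# Sketch: the version-extraction regex from find_version, applied to a
# hypothetical file body.
import re

text = '__version__ = "0.1.7"\n'
match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", text, re.M)
assert match is not None and match.group(1) == "0.1.7"
```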
```diff
@@ -162,7 +181,8 @@ setuptools.setup(
     version=find_version(get_path("vllm", "__init__.py")),
     author="vLLM Team",
     license="Apache 2.0",
-    description="A high-throughput and memory-efficient inference and serving engine for LLMs",
+    description=("A high-throughput and memory-efficient inference and "
+                 "serving engine for LLMs"),
     long_description=read_readme(),
     long_description_content_type="text/markdown",
     url="https://github.com/vllm-project/vllm",
@@ -174,11 +194,12 @@ setuptools.setup(
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
         "License :: OSI Approved :: Apache Software License",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
-    packages=setuptools.find_packages(
-        exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
+    packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
+                                               "examples", "tests")),
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
```
```diff
@@ -133,9 +133,10 @@ def test_rotary_embedding(
                          device="cuda")
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
```
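The test now builds `inv_freq` in float32 explicitly, so the later `.float()` cast at the einsum can be dropped; this mirrors the same change in `PagedAttentionWithRoPE` further down. A standalone sketch with illustrative parameter values:

```python
# Sketch of the updated cos/sin cache construction; base, rotary_dim, and
# max_position are illustrative values, not taken from the test.
import torch

base, rotary_dim, max_position = 10000, 64, 2048
inv_freq = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
assert cos_sin_cache.dtype == torch.float32
assert cos_sin_cache.shape == (max_position, rotary_dim)
```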
```diff
@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 
-__version__ = "0.1.6"
+__version__ = "0.1.7"
 
 __all__ = [
     "LLM",
```
```diff
@@ -114,8 +114,9 @@ class ModelConfig:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
         new_decoder_arch_falcon = (
-            self.hf_config.model_type == "falcon"
+            self.hf_config.model_type in falcon_model_types
             and getattr(self.hf_config, "new_decoder_architecture", False))
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):
```
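Falcon checkpoints on the Hub carry different `model_type` strings depending on when they were published, so the new-decoder-architecture check now accepts all three spellings. The predicate in isolation (the helper wrapper is mine, for illustration):

```python
# Sketch: the broadened Falcon check. `hf_config` stands in for a
# transformers PretrainedConfig; the helper name is illustrative.
FALCON_MODEL_TYPES = ["falcon", "RefinedWeb", "RefinedWebModel"]

def is_new_decoder_arch_falcon(hf_config) -> bool:
    return (hf_config.model_type in FALCON_MODEL_TYPES
            and getattr(hf_config, "new_decoder_architecture", False))
```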
```diff
@@ -153,7 +153,7 @@ class LLMEngine:
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
                 **ray_remote_kwargs,
-            )(RayWorker).remote()
+            )(RayWorker).remote(self.model_config.trust_remote_code)
             self.workers.append(worker)
 
         # Initialize torch distributed process group for the workers.
```
```diff
@@ -11,7 +11,11 @@ try:
         """Ray wrapper for vllm.worker.Worker, allowing Worker to be
         lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
-        def __init__(self) -> None:
+        def __init__(self, init_cached_hf_modules=False) -> None:
+            if init_cached_hf_modules:
+                # pylint: disable=import-outside-toplevel
+                from transformers.dynamic_module_utils import init_hf_modules
+                init_hf_modules()
             self.worker = None
 
         def init_worker(self, worker_init_fn):
```
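These two hunks work together: the engine now forwards `trust_remote_code` into each Ray actor, and the actor eagerly runs `init_hf_modules()` so the transformers dynamic-module setup happens inside the worker process (each Ray actor is a fresh process, so driver-side initialization does not carry over). A standalone sketch of the pattern, assuming a running Ray cluster:

```python
# Sketch: mirrors the hunks above in self-contained form. Each Ray actor
# is a separate process, so HF dynamic modules are initialized per actor.
import ray

@ray.remote
class RayWorker:

    def __init__(self, init_cached_hf_modules: bool = False) -> None:
        if init_cached_hf_modules:
            # Prepares transformers' module cache for trust_remote_code models.
            from transformers.dynamic_module_utils import init_hf_modules
            init_hf_modules()
        self.worker = None

worker = RayWorker.remote(True)  # engine passes model_config.trust_remote_code
```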
```diff
@@ -73,7 +73,12 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(
+        self,
+        input_metadata: InputMetadata,
+        dtype: torch.dtype,
+    ) -> None:
+        del dtype  # Unused.
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
@@ -196,7 +201,7 @@ class PagedAttention(nn.Module):
         if num_prompt_tokens > 0:
             # Prompt run.
             assert input_metadata.num_generation_tokens == 0
-            self.set_attn_bias(input_metadata)
+            self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -259,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         self.is_neox_style = is_neox_style
 
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, device="cuda").float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
+        t = torch.arange(max_position, dtype=torch.float, device="cuda")
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
@@ -340,13 +345,14 @@ class PagedAttentionWithALiBi(PagedAttention):
         slopes = torch.tensor(slopes, dtype=torch.float32)
         self.register_buffer("alibi_slopes", slopes, persistent=False)
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(self, input_metadata: InputMetadata,
+                      dtype: torch.dtype) -> None:
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len)
+            bias = torch.arange(prompt_len, dtype=dtype)
             # Note(zhuohan): HF uses
             #     `bias = bias[None, :].repeat(prompt_len, 1)`
             # here. We find that both biases give the same results, but
@@ -364,6 +370,7 @@ class PagedAttentionWithALiBi(PagedAttention):
                 prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
+                dtype=dtype,
             )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
```
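The common thread in these hunks: attention bias tensors are now created in the query's dtype rather than defaulting to float32, which avoids an implicit upcast (and a float32-sized buffer) when serving in fp16/bf16. The base class accepts the new `dtype` argument and immediately deletes it, since only the ALiBi subclass needs it. A minimal sketch of a dtype-aware position bias (the relative-position layout here is illustrative, not copied from vLLM):

```python
# Sketch: build a position bias directly in the model dtype, as the
# updated set_attn_bias does. vLLM's actual bias layout differs.
import torch

def make_position_bias(prompt_len: int, dtype: torch.dtype) -> torch.Tensor:
    pos = torch.arange(prompt_len, dtype=dtype)
    return pos[None, :] - pos[:, None]  # relative distances, in `dtype`

bias = make_position_bias(4, torch.float16)
assert bias.dtype == torch.float16
```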