Compare commits

10 Commits

Author SHA1 Message Date
Woosuk Kwon
90eb3f43ca Bump up the version to v0.1.7 (#1013)
2023-09-11 00:54:30 -07:00
Woosuk Kwon
e67b4f2c2a Use FP32 in RoPE initialization (#1004)
Co-authored-by: One <imone@tuta.io>
2023-09-11 00:26:35 -07:00
Woosuk Kwon
d6770d1f23 Update setup.py (#1006) 2023-09-10 23:42:45 -07:00
Woosuk Kwon
b9cecc2635 [Docs] Update installation page (#1005) 2023-09-10 14:23:31 -07:00
Kyujin Cho
898285c9bf fix: CUDA error when inferencing with Falcon-40B base model (#992) 2023-09-10 01:39:02 -07:00
Antoni Baum
a62de9ecfd Fix wrong dtype in PagedAttentionWithALiBi bias (#996)
---------

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
2023-09-09 14:58:35 -07:00
Jingru
4042d192f5 fix "tansformers_module" ModuleNotFoundError when load model with trust_remote_code=True (#871) 2023-09-08 17:21:30 -07:00
Zhuohan Li
1117aa1411 Bump up the version to v0.1.6 (#989)
2023-09-08 00:07:46 -07:00
Antoni Baum
080438477f Start background task in AsyncLLMEngine.generate (#988)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-09-08 00:03:39 -07:00
Robert Irvine
4b5bcf8906 faster startup of vLLM (#982)
* update

---------

Co-authored-by: Robert Irvine <robert@seamlessml.com>
2023-09-08 14:48:54 +09:00
12 changed files with 89 additions and 70 deletions

View File

@@ -3,31 +3,15 @@
 Installation
 ============
 
-vLLM is a Python library that also contains some C++ and CUDA code.
-This additional code requires compilation on the user's machine.
+vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
 
 Requirements
 ------------
 
 * OS: Linux
-* Python: 3.8 or higher
-* CUDA: 11.0 -- 11.8
+* Python: 3.8 -- 3.11
 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
 
-.. note::
-    As of now, vLLM does not support CUDA 12.
-    If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
-
-.. tip::
-    If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
-
-    .. code-block:: console
-
-        $ # Pull the Docker image with CUDA 11.8.
-        $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
-
-    Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
-
 Install with pip
 ----------------
@@ -40,7 +24,7 @@ You can install vLLM using pip:
 
     $ conda activate myenv
 
     $ # Install vLLM.
-    $ pip install vllm  # This may take 5-10 minutes.
+    $ pip install vllm
 
 .. _build_from_source:
@@ -55,3 +39,11 @@ You can also build and install vLLM from source:
 
     $ git clone https://github.com/vllm-project/vllm.git
     $ cd vllm
     $ pip install -e .  # This may take 5-10 minutes.
+
+.. tip::
+    If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
+
+    .. code-block:: console
+
+        $ # Pull the Docker image with CUDA 11.8.
+        $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
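Since the wheels now ship pre-compiled CUDA 11.8 binaries instead of compiling on the user's machine, a quick sanity check of an environment is to confirm the CUDA version of the local PyTorch build and the GPU's compute capability. A minimal sketch (the printed values vary by environment):

    import torch

    # The pre-built vLLM wheels target CUDA 11.8, so the local PyTorch build
    # should report a matching CUDA version.
    print(torch.version.cuda)                  # expected: "11.8"
    # vLLM requires compute capability 7.0 or higher, e.g. (8, 0) on an A100.
    print(torch.cuda.get_device_capability())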

View File

@@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 
 if CUDA_HOME is None:
     raise RuntimeError(
-        f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
+        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
 
 
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -54,7 +54,8 @@ if nvcc_cuda_version < Version("11.0"):
     raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
 if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
     raise RuntimeError(
-        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
+        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
+    )
 if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
     # However, GPUs with compute capability 8.9 can also run the code generated by
@@ -65,7 +66,8 @@ if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     compute_capabilities.add(80)
 if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     raise RuntimeError(
-        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
+        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
+    )
 
 # If no GPU is available, add all supported compute capabilities.
 if not compute_capabilities:
@@ -78,7 +80,9 @@ if not compute_capabilities:
 
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+    NVCC_FLAGS += [
+        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
+    ]
 
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
@@ -91,7 +95,10 @@ ext_modules = []
 cache_extension = CUDAExtension(
     name="vllm.cache_ops",
     sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(cache_extension)
@@ -99,7 +106,10 @@ ext_modules.append(cache_extension)
 attention_extension = CUDAExtension(
     name="vllm.attention_ops",
     sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(attention_extension)
@@ -107,7 +117,10 @@ ext_modules.append(attention_extension)
 positional_encoding_extension = CUDAExtension(
     name="vllm.pos_encoding_ops",
     sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(positional_encoding_extension)
@@ -115,7 +128,10 @@ ext_modules.append(positional_encoding_extension)
 layernorm_extension = CUDAExtension(
     name="vllm.layernorm_ops",
     sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(layernorm_extension)
@@ -123,7 +139,10 @@ ext_modules.append(layernorm_extension)
 activation_extension = CUDAExtension(
     name="vllm.activation_ops",
     sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(activation_extension)
@@ -138,8 +157,8 @@ def find_version(filepath: str):
     Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
     """
     with open(filepath) as fp:
-        version_match = re.search(
-            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
+        version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+                                  fp.read(), re.M)
         if version_match:
             return version_match.group(1)
         raise RuntimeError("Unable to find version string.")
@@ -162,7 +181,8 @@ setuptools.setup(
     version=find_version(get_path("vllm", "__init__.py")),
     author="vLLM Team",
     license="Apache 2.0",
-    description="A high-throughput and memory-efficient inference and serving engine for LLMs",
+    description=("A high-throughput and memory-efficient inference and "
+                 "serving engine for LLMs"),
     long_description=read_readme(),
     long_description_content_type="text/markdown",
     url="https://github.com/vllm-project/vllm",
@@ -174,11 +194,12 @@ setuptools.setup(
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
         "License :: OSI Approved :: Apache Software License",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
-    packages=setuptools.find_packages(
-        exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
+    packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
+                                               "examples", "tests")),
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
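Most of these setup.py hunks are pure rewrapping (consistent with a formatter pass such as yapf); behavior is unchanged. For instance, the rewrapped -gencode loop still emits one arch/code pair per detected compute capability. A self-contained sketch of the expansion (the capability values here are illustrative, not from the diff):

    # Illustrative expansion of the NVCC flag loop above.
    compute_capabilities = {70, 80}  # e.g. a V100 and an A100 were detected
    NVCC_FLAGS = []
    for capability in sorted(compute_capabilities):
        NVCC_FLAGS += [
            "-gencode", f"arch=compute_{capability},code=sm_{capability}"
        ]
    print(NVCC_FLAGS)
    # ['-gencode', 'arch=compute_70,code=sm_70',
    #  '-gencode', 'arch=compute_80,code=sm_80']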

View File

@@ -40,8 +40,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
-                                                      start_engine_loop=False)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
     vllm.entrypoints.api_server.engine = engine
     uvicorn.run(
         app,

View File

@@ -133,9 +133,10 @@ def test_rotary_embedding(
                          device="cuda")
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
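The dtype pin matters because the int-to-float promotion in torch.arange(0, rotary_dim, 2) / rotary_dim follows PyTorch's default dtype, and vLLM initializes fp16 models under a half-precision default, so the rotary frequencies used to be computed in fp16. A minimal standalone sketch of the discrepancy the fix removes (assuming that default-dtype mechanism):

    import torch

    rotary_dim, base = 128, 10000.0

    # Under a half default dtype (as during fp16 model initialization), the
    # int -> float promotion happens in fp16 and rounds the exponents.
    torch.set_default_dtype(torch.half)
    inv_freq_fp16 = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))

    torch.set_default_dtype(torch.float)
    inv_freq_fp32 = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

    # Small but non-zero; enough to perturb rotary embeddings at long range.
    print((inv_freq_fp16.float() - inv_freq_fp32).abs().max())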

View File

@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 
-__version__ = "0.1.5"
+__version__ = "0.1.7"
 
 __all__ = [
     "LLM",

View File

@@ -114,8 +114,9 @@ class ModelConfig:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
         new_decoder_arch_falcon = (
-            self.hf_config.model_type == "falcon"
+            self.hf_config.model_type in falcon_model_types
             and getattr(self.hf_config, "new_decoder_architecture", False))
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):

View File

@@ -230,6 +230,8 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        start_engine_loop: If True, the background task to run the engine
+            will be automatically started in the generate call.
         *args, *kwargs: Arguments for LLMEngine.
     """
 
@@ -240,7 +242,7 @@ class AsyncLLMEngine:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 start_engine_loop: bool = False,
+                 start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
@@ -249,8 +251,7 @@ class AsyncLLMEngine:
         self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
-        if start_engine_loop:
-            self.start_background_loop()
+        self.start_engine_loop = start_engine_loop
 
     @property
     def is_running(self) -> bool:
@@ -330,11 +331,14 @@ class AsyncLLMEngine:
                     f"prompt token ids: {prompt_token_ids}.")
 
         if not self.is_running:
-            raise AsyncEngineDeadError(
-                "Background loop is not running. If it was running, "
-                "inspect the output to find the stacktrace of the "
-                "error that caused the background loop to stop "
-                "(AsyncEngineDeadError).")
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise AsyncEngineDeadError(
+                    "Background loop is not running. If it was running, "
+                    "inspect the output to find the stacktrace of the "
+                    "error that caused the background loop to stop "
+                    "(AsyncEngineDeadError).")
 
         stream = self.request_tracker.add_request(
             request_id,
@@ -426,7 +430,7 @@ class AsyncLLMEngine:
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = False) -> "AsyncLLMEngine":
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
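With the default flipped to True, callers no longer manage the background loop themselves: the first generate() call starts it on the running event loop, which is why the explicit start_engine_loop=False arguments and manual start_background_loop() calls disappear from the entrypoints below. A hedged usage sketch (the model name is illustrative):

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.sampling_params import SamplingParams
    from vllm.utils import random_uuid

    async def main():
        # start_engine_loop=True is now the default, so no manual
        # start_background_loop() call is needed before generating.
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))
        final_output = None
        async for output in engine.generate("Hello, my name is",
                                            SamplingParams(max_tokens=16),
                                            random_uuid()):
            final_output = output
        print(final_output.outputs[0].text)

    asyncio.run(main())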

View File

@@ -153,7 +153,7 @@ class LLMEngine:
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
                 **ray_remote_kwargs,
-            )(RayWorker).remote()
+            )(RayWorker).remote(self.model_config.trust_remote_code)
             self.workers.append(worker)
 
         # Initialize torch distributed process group for the workers.

View File

@@ -11,7 +11,11 @@ try:
         """Ray wrapper for vllm.worker.Worker, allowing Worker to be
         lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
-        def __init__(self) -> None:
+        def __init__(self, init_cached_hf_modules=False) -> None:
+            if init_cached_hf_modules:
+                # pylint: disable=import-outside-toplevel
+                from transformers.dynamic_module_utils import init_hf_modules
+                init_hf_modules()
             self.worker = None
 
         def init_worker(self, worker_init_fn):
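Together with the llm_engine.py change above, this addresses the "transformers_modules" ModuleNotFoundError from #871: with trust_remote_code=True, transformers writes the model's remote code into a dynamically created transformers_modules package and registers its location on sys.path, but only in the process that loaded the config. A fresh Ray worker process has no such registration, so anything unpickled there that references transformers_modules fails to import. Calling init_hf_modules() in the actor's constructor recreates the registration; a minimal sketch of what it does:

    # What init_cached_hf_modules=True triggers inside each Ray worker.
    from transformers.dynamic_module_utils import init_hf_modules

    # Creates the HF dynamic-modules cache (by default under
    # ~/.cache/huggingface/modules), adds it to sys.path, and gives it an
    # __init__.py, so "transformers_modules.<model>" packages written by
    # the driver process become importable from this worker process too.
    init_hf_modules()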

View File

@@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -80,8 +77,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     uvicorn.run(app,
                 host=args.host,

View File

@@ -192,9 +192,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     """
     logger.info(f"Received chat completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -367,9 +364,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     """
     logger.info(f"Received completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -627,8 +621,7 @@ if __name__ == "__main__":
     served_model = args.model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
     max_model_len = engine_model_config.get_max_model_len()

View File

@@ -73,7 +73,12 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(
+        self,
+        input_metadata: InputMetadata,
+        dtype: torch.dtype,
+    ) -> None:
+        del dtype  # Unused.
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
@@ -196,7 +201,7 @@ class PagedAttention(nn.Module):
         if num_prompt_tokens > 0:
             # Prompt run.
             assert input_metadata.num_generation_tokens == 0
-            self.set_attn_bias(input_metadata)
+            self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -259,9 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         self.is_neox_style = is_neox_style
 
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
-        t = torch.arange(max_position).float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
+        t = torch.arange(max_position, dtype=torch.float, device="cuda")
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
@@ -339,13 +345,14 @@ class PagedAttentionWithALiBi(PagedAttention):
         slopes = torch.tensor(slopes, dtype=torch.float32)
         self.register_buffer("alibi_slopes", slopes, persistent=False)
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(self, input_metadata: InputMetadata,
+                      dtype: torch.dtype) -> None:
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len)
+            bias = torch.arange(prompt_len, dtype=dtype)
             # Note(zhuohan): HF uses
             #     `bias = bias[None, :].repeat(prompt_len, 1)`
             # here. We find that both biases give the same results, but
@@ -363,6 +370,7 @@ class PagedAttentionWithALiBi(PagedAttention):
                 prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
+                dtype=dtype,
             )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
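The underlying bug from #996: the ALiBi bias was always materialized in fp32, so fp16/bf16 models handed the attention kernel a mask whose dtype disagreed with the queries. Threading dtype through set_attn_bias builds the mask in the model dtype from the start; the base PagedAttention implementation simply discards the argument, since its plain causal mask needs no tensor bias. A standalone sketch of the corrected bias construction (the shapes and slope values are illustrative):

    import torch

    prompt_len, num_heads = 8, 4
    dtype = torch.float16  # the query dtype for an fp16 model

    # Build the relative-position bias directly in the model dtype, as the
    # fixed set_attn_bias does.
    bias = torch.arange(prompt_len, dtype=dtype)
    bias = bias[None, :] - bias[:, None]          # (prompt_len, prompt_len)

    slopes = torch.rand(num_heads)                # stand-in ALiBi slopes (fp32)
    bias = bias[None, :, :].repeat(num_heads, 1, 1)
    bias.mul_(slopes[:, None, None])              # in-place, result stays fp16

    print(bias.dtype)  # torch.float16, matching the attention inputs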