Compare commits

..

61 Commits

Author SHA1 Message Date
Woosuk Kwon
e2fb71ec9f Bump up the version to v0.2.0 (#1212)
2023-09-28 15:30:38 -07:00
Woosuk Kwon
f936657eb6 Provide default max model length (#1224) 2023-09-28 14:44:02 -07:00
Woosuk Kwon
6f88f762bf Fix OOM in attention kernel test (#1223) 2023-09-28 14:33:24 -07:00
Woosuk Kwon
202351d5bf Add Mistral to supported model list (#1221) 2023-09-28 14:33:04 -07:00
Woosuk Kwon
2e8e49fce3 [Fix] Remove false assertion (#1222) 2023-09-28 10:52:38 -07:00
Woosuk Kwon
a8e98aee0c Fix Mistral model (#1220) 2023-09-28 10:44:05 -07:00
Chris Bamford
bb1ba58f06 [Mistral] Mistral-7B-v0.1 support (#1196)
Co-authored-by: timlacroix <t@mistral.ai>
2023-09-28 10:41:03 -07:00
Qing
7bedab5748 Add rope_scaling to Qwen (#1210) 2023-09-28 00:49:23 -07:00
Dan Lord
20f7cc4cde Add skip_special_tokens sampling params (#1186) 2023-09-27 19:21:42 -07:00
Danilo Peixoto
649aa730c5 Use standard extras for uvicorn (#1166) 2023-09-27 17:41:36 -07:00
Woosuk Kwon
a19bc5c628 Automatically configure max_num_batched_tokens (#1198) 2023-09-27 16:34:00 -07:00
Qing
28e616c4e3 fix qwen-14b model (#1173) 2023-09-27 16:33:16 -07:00
Wang Ran (汪然)
30e775281d fix typo (#1184)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-09-27 16:22:45 -07:00
Lily Liu
21877b0d75 Support Longchat and RoPE scaling (#555)
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-09-27 03:36:02 -07:00
Antoni Baum
cf5cb1e33e Allocate more shared memory to attention kernel (#1154) 2023-09-26 22:27:13 -07:00
Woosuk Kwon
03ffd0a022 Add comments on RoPE initialization (#1176) 2023-09-26 10:48:33 -07:00
Woosuk Kwon
a425bd9a9a [Setup] Enable TORCH_CUDA_ARCH_LIST for selecting target GPUs (#1074) 2023-09-26 10:21:08 -07:00
Wen Sun
bbbf86565f Align max_tokens behavior with openai (#852) 2023-09-23 18:10:13 -07:00
Woosuk Kwon
9f6be8692e Fix config for Falcon (#1164) 2023-09-23 17:38:43 -07:00
Zhuohan Li
f187877945 [FIX] Simplify sampler logic (#1156) 2023-09-23 17:21:56 -07:00
Zhuohan Li
947b794146 [Sampler] Vectorized sampling (simplified) (#1048)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
2023-09-22 17:48:04 -07:00
Woosuk Kwon
8d926e91f1 Announce the First vLLM Meetup (#1148) 2023-09-22 11:37:14 -07:00
Nick Perez
4ee52bb169 Docs: Fix broken link to openai example (#1145)
Link to `openai_client.py` is no longer valid - updated to `openai_completion_client.py`
2023-09-22 11:36:09 -07:00
Woosuk Kwon
7d7e3b78a3 Use --ipc=host in docker run for distributed inference (#1125) 2023-09-21 18:26:47 -07:00
Ricardo Lu
f98b745a81 feat: support stop_token_ids parameter. (#1097) 2023-09-21 15:34:02 -07:00
Roy
2d1e86f1b1 clean api code, remove redundant background task. (#1102) 2023-09-21 13:25:05 -07:00
Woosuk Kwon
1ac4ccf73c Add float16 and float32 (#1115) 2023-09-21 00:52:47 -07:00
Woosuk Kwon
2ac4d5e2bf Replace DtypeTensor (#1123) 2023-09-21 00:51:47 -07:00
Antoni Baum
3302f0aef3 rope_theta and max_position_embeddings from config (#1096)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: wnma3mz <wnma3mz@gmail.com>
2023-09-20 13:35:11 -07:00
Tanmay Verma
6f2dd6c37e Add documentation to Triton server tutorial (#983) 2023-09-20 10:32:40 -07:00
Woosuk Kwon
bc0644574c Add gpu_memory_utilization and swap_space to LLM (#1090) 2023-09-19 22:16:04 -07:00
Woosuk Kwon
400b8289f7 Add pyarrow to dependencies & Print warning on Ray import error (#1094) 2023-09-18 22:36:17 -07:00
Zhuohan Li
c1026311b5 [Community] Add vLLM Discord server (#1086) 2023-09-18 12:23:35 -07:00
Woosuk Kwon
2b1c116b5a Add minimum capability requirement for AWQ (#1064) 2023-09-18 12:02:01 -07:00
Woosuk Kwon
cc796b1358 Convert before transpose (#1073) 2023-09-18 11:51:48 -07:00
Zhuohan Li
f029ef94d7 Fix get_max_num_running_seqs for waiting and swapped seq groups (#1068) 2023-09-18 11:49:40 -07:00
Roy
95592fa00a align llm_engine and async_engine. (#1081) 2023-09-18 11:49:10 -07:00
orellavie1212
fbe66e1d0b added support for quantize on LLM module (#1080) 2023-09-18 11:04:21 -07:00
Zhuohan Li
90979c38f8 [FIX] Don't initialize parameter by default (#1067) 2023-09-17 17:15:38 -07:00
陈序
e21d7687a9 Fix hanging when prompt exceeds limit (#1029) 2023-09-17 01:48:56 -07:00
Antoni Baum
ff36139ffc Remove AsyncLLMEngine busy loop, shield background task (#1059) 2023-09-17 00:29:08 -07:00
Woosuk Kwon
e3e79e9e8a Implement AWQ quantization support for LLaMA (#1032)
Co-authored-by: Robert Irvine <robert@seamlessml.com>
Co-authored-by: root <rirv938@gmail.com>
Co-authored-by: Casper <casperbh.96@gmail.com>
Co-authored-by: julian-q <julianhquevedo@gmail.com>
2023-09-16 00:03:37 -07:00
Jerry Yang
b9fe4616f9 Abort when coroutine is cancelled (#1020) 2023-09-14 17:40:18 -07:00
Woosuk Kwon
64ca424e75 Fix warning message on LLaMA FastTokenizer (#1037) 2023-09-14 17:33:32 -07:00
Lukas Kreussel
b5f93d0631 Only fail if logit_bias has actual values (#1045) 2023-09-14 17:33:01 -07:00
Woosuk Kwon
a58936966f Add pandas to requirements.txt (#1047)
* Add pandas to requirements.txt

* Minor
2023-09-14 17:31:38 -07:00
Antoni Baum
dd54a4b026 Fix detokenization leaving special tokens (#1044)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
2023-09-14 16:37:03 -07:00
Woosuk Kwon
eda1a7cad3 Announce paper release (#1036) 2023-09-13 17:38:13 -07:00
Zhuohan Li
f04908cae7 [FIX] Minor bug fixes (#1035)
* [FIX] Minor bug fixes

* Address review comments
2023-09-13 16:38:12 -07:00
Jasmond L
ab019eea75 Add Model Revision Support (#1014)
Co-authored-by: Jasmond Loh <Jasmond.Loh@hotmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-09-13 15:20:02 -07:00
Antoni Baum
9841d48a10 Use TGI-like incremental detokenization (#984) 2023-09-13 13:38:01 -07:00
Ikko Eltociear Ashimine
3272d7a0b7 Fix typo in README.md (#1033) 2023-09-13 12:55:23 -07:00
Antoni Baum
0bb1e885a0 Make max_model_len configurable (#972) 2023-09-12 16:29:19 -07:00
leiwen83
d6545ad22e add option to shorten prompt print in log (#991)
Signed-off-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-09-12 15:10:14 -07:00
Woosuk Kwon
90eb3f43ca Bump up the version to v0.1.7 (#1013)
2023-09-11 00:54:30 -07:00
Woosuk Kwon
e67b4f2c2a Use FP32 in RoPE initialization (#1004)
Co-authored-by: One <imone@tuta.io>
2023-09-11 00:26:35 -07:00
Woosuk Kwon
d6770d1f23 Update setup.py (#1006) 2023-09-10 23:42:45 -07:00
Woosuk Kwon
b9cecc2635 [Docs] Update installation page (#1005) 2023-09-10 14:23:31 -07:00
Kyujin Cho
898285c9bf fix: CUDA error when inferencing with Falcon-40B base model (#992) 2023-09-10 01:39:02 -07:00
Antoni Baum
a62de9ecfd Fix wrong dtype in PagedAttentionWithALiBi bias (#996)
---------

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
2023-09-09 14:58:35 -07:00
Jingru
4042d192f5 fix "tansformers_module" ModuleNotFoundError when load model with trust_remote_code=True (#871) 2023-09-08 17:21:30 -07:00
73 changed files with 3523 additions and 832 deletions

.gitignore

@@ -173,3 +173,7 @@ cython_debug/
 # Sphinx documentation
 _build/
+
+# vim swap files
+*.swo
+*.swp


@@ -10,13 +10,24 @@ Easy, fast, and cheap LLM serving for everyone
 </h3>

 <p align="center">
-| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
+| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
 </p>

 ---
+
+**The First vLLM Bay Area Meetup (Oct 5th 6pm-8pm PT)**
+
+We are excited to invite you to the first vLLM meetup!
+The vLLM team will share recent updates and roadmap.
+We will also have vLLM users and contributors coming up to the stage to share their experiences.
+Please register [here](https://lu.ma/first-vllm-meetup) and join us!
+
+---

 *Latest News* 🔥
+- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
+- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
 - [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
 - [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
 - [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
@@ -35,13 +46,13 @@ vLLM is fast with:

 vLLM is flexible and easy to use with:

-- Seamless integration with popular HuggingFace models
+- Seamless integration with popular Hugging Face models
 - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
 - Tensor parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server

-vLLM seamlessly supports many Huggingface models, including the following architectures:
+vLLM seamlessly supports many Hugging Face models, including the following architectures:

 - Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
 - Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
@@ -53,6 +64,7 @@ vLLM seamlessly supports many Huggingface models, including the following archit
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
 - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
@@ -72,7 +84,7 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
 ## Performance

-vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
+vLLM outperforms Hugging Face Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
 For details, check out our [blog post](https://vllm.ai).

 <p align="center">
@@ -104,3 +116,15 @@ For details, check out our [blog post](https://vllm.ai).
 We welcome and value any contributions and collaborations.
 Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
+
+## Citation
+
+If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
+```bibtex
+@inproceedings{kwon2023efficient,
+  title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
+  author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
+  booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
+  year={2023}
+}
+```


@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
     llm = LLM(
         model=args.model,
         tokenizer=args.tokenizer,
+        quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
         max_num_batched_tokens=args.batch_size * args.input_len,
@@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description='Benchmark the latency of processing a single batch of '
         'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
     parser.add_argument('--tokenizer', type=str, default=None)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=['awq', None],
+                        default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1,
+    parser.add_argument('--n',
+                        type=int,
+                        default=1,
                         help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
-    parser.add_argument('--num-iters', type=int, default=3,
+    parser.add_argument('--num-iters',
+                        type=int,
+                        default=3,
                         help='Number of iterations to run.')
-    parser.add_argument('--trust-remote-code', action='store_true',
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
                         help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
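For reference, the new `--quantization` flag simply threads through to the `LLM` constructor. A minimal sketch of the equivalent Python call, assuming an AWQ-quantized checkpoint is available (the model name below is illustrative, not something used in this PR):

```python
from vllm import LLM, SamplingParams

# Roughly equivalent to: python benchmark_latency.py --model <awq-model> -q awq
# "TheBloke/Llama-2-7B-AWQ" is a hypothetical example checkpoint.
llm = LLM(model="TheBloke/Llama-2-7B-AWQ", quantization="awq")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```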


@@ -3,7 +3,7 @@ import argparse
 import json
 import random
 import time
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
@@ -22,15 +22,10 @@ def sample_requests(
     with open(dataset_path) as f:
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
-    dataset = [
-        data for data in dataset
-        if len(data["conversations"]) >= 2
-    ]
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]

     # Tokenize the prompts and completions.
     prompts = [prompt for prompt, _ in dataset]
@@ -63,6 +58,7 @@ def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
     tokenizer: str,
+    quantization: Optional[str],
     tensor_parallel_size: int,
     seed: int,
     n: int,
@@ -72,6 +68,7 @@ def run_vllm(
     llm = LLM(
         model=model,
         tokenizer=tokenizer,
+        quantization=quantization,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
         trust_remote_code=trust_remote_code,
@@ -111,8 +108,8 @@ def run_hf(
     trust_remote_code: bool,
 ) -> float:
     assert not use_beam_search
-    llm = AutoModelForCausalLM.from_pretrained(model,
-        torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
+    llm = AutoModelForCausalLM.from_pretrained(
+        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
     if llm.config.model_type == "llama":
         # To enable padding in the HF backend.
         tokenizer.pad_token = tokenizer.eos_token
@@ -132,13 +129,14 @@ def run_hf(
         if len(batch) < max_batch_size and i != len(requests) - 1:
             # Check if we can add more requests to the batch.
             _, next_prompt_len, next_output_len = requests[i + 1]
-            if (max(max_prompt_len, next_prompt_len) + max(
-                max_output_len, next_output_len)) <= 2048:
+            if (max(max_prompt_len, next_prompt_len) +
+                    max(max_output_len, next_output_len)) <= 2048:
                 # We can add more requests to the batch.
                 continue

         # Generate the sequences.
-        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
+        input_ids = tokenizer(batch, return_tensors="pt",
+                              padding=True).input_ids
         llm_outputs = llm.generate(
             input_ids=input_ids.cuda(),
             do_sample=not use_beam_search,
@@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)

     # Sample the requests.
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    tokenizer = get_tokenizer(args.tokenizer,
+                              trust_remote_code=args.trust_remote_code)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     if args.backend == "vllm":
-        elapsed_time = run_vllm(
-            requests, args.model, args.tokenizer, args.tensor_parallel_size,
-            args.seed, args.n, args.use_beam_search, args.trust_remote_code)
+        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
+                                args.quantization, args.tensor_parallel_size,
+                                args.seed, args.n, args.use_beam_search,
+                                args.trust_remote_code)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
-        elapsed_time = run_hf(
-            requests, args.model, tokenizer, args.n, args.use_beam_search,
-            args.hf_max_batch_size, args.trust_remote_code)
+        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
+                              args.use_beam_search, args.hf_max_batch_size,
+                              args.trust_remote_code)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
-    total_num_tokens = sum(
-        prompt_len + output_len
-        for _, prompt_len, output_len in requests
-    )
+    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} tokens/s")


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
-    parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
+    parser.add_argument("--backend",
+                        type=str,
+                        choices=["vllm", "hf"],
                         default="vllm")
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
     parser.add_argument("--tokenizer", type=str, default=None)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=['awq', None],
+                        default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
-    parser.add_argument("--n", type=int, default=1,
+    parser.add_argument("--n",
+                        type=int,
+                        default=1,
                         help="Number of generated sequences per prompt.")
     parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                         help="Number of prompts to process.")
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--hf-max-batch-size", type=int, default=None,
+    parser.add_argument("--hf-max-batch-size",
+                        type=int,
+                        default=None,
                         help="Maximum batch size for HF backend.")
     parser.add_argument('--trust-remote-code',
                         action='store_true',
@@ -215,6 +227,8 @@ if __name__ == "__main__":
     elif args.backend == "hf":
         if args.hf_max_batch_size is None:
             raise ValueError("HF max batch size is required for HF backend.")
+        if args.quantization is not None:
+            raise ValueError("Quantization is only for vLLM backend.")
     if args.tokenizer is None:
         args.tokenizer = args.model
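The HF-backend batching logic in run_hf above packs requests greedily while keeping the padded batch within a 2048-token budget. A standalone sketch of that condition (the function name and defaults are illustrative, not part of the benchmark script):

```python
def can_extend_batch(max_prompt_len: int, max_output_len: int,
                     next_prompt_len: int, next_output_len: int,
                     budget: int = 2048) -> bool:
    # Mirrors the check in run_hf: the padded batch must fit within `budget`
    # tokens, counting its longest prompt plus its longest output.
    return (max(max_prompt_len, next_prompt_len) +
            max(max_output_len, next_output_len)) <= budget

# Example: a batch whose longest prompt is 512 tokens and longest output is
# 1024 tokens can still absorb a (600, 300) request, since 600 + 1024 <= 2048.
assert can_extend_batch(512, 1024, 600, 300)
```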


@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 } // namespace vllm

 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
+  cudaFuncSetAttribute( \
+    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
   vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
   <<<grid, block, shared_mem_size, stream>>>( \
     out_ptr, \
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);

   dim3 grid(num_heads, num_seqs);
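The added cudaFuncSetAttribute call opts the kernel into more dynamic shared memory, and the new comment asks that the Python-side admission check stay in sync with this sizing. A small sketch of the same arithmetic; the default block size and thread count below are assumptions for illustration, not values taken from this diff:

```python
def attention_shared_mem_bytes(max_context_len: int, head_size: int,
                               block_size: int = 16,
                               num_threads: int = 128) -> int:
    # Mirrors the launcher above:
    #   padded_max_context_len = ceil(max_context_len / BLOCK_SIZE) * BLOCK_SIZE
    #   logits_size  = padded_max_context_len * sizeof(float)
    #   outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float)
    num_warps = num_threads // 32
    padded_len = (max_context_len + block_size - 1) // block_size * block_size
    logits_size = padded_len * 4
    outputs_size = (num_warps // 2) * head_size * 4
    return max(logits_size, outputs_size)

# Example: a 4096-token context with head_size 128 needs
# max(4096 * 4, 2 * 128 * 4) = 16384 bytes of dynamic shared memory.
assert attention_shared_mem_bytes(4096, 128) == 16384
```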

csrc/cuda_utils.cpp

@@ -0,0 +1,13 @@
#include <torch/extension.h>
int get_device_attribute(
int attribute,
int device_id);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
}

View File

@@ -0,0 +1,14 @@
int get_device_attribute(
int attribute,
int device_id)
{
int device, value;
if (device_id < 0) {
cudaGetDevice(&device);
}
else {
device = device_id;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
}
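This helper is a thin wrapper around cudaDeviceGetAttribute; callers pass the numeric value of a cudaDeviceAttr enum. A hedged sketch of how Python code might use it to read the opt-in shared-memory limit; the module path vllm.cuda_utils and the enum value 97 for cudaDevAttrMaxSharedMemoryPerBlockOptin are assumptions here, so verify them against your build and CUDA headers:

```python
from vllm import cuda_utils  # assumed name of the compiled extension

# cudaDevAttrMaxSharedMemoryPerBlockOptin is 97 in recent CUDA headers
# (assumption - check driver_types.h for your toolkit version).
MAX_SHARED_MEM_PER_BLOCK_OPTIN = 97

# device_id < 0 means "use the current device", per the kernel source above.
max_shared_mem = cuda_utils.get_device_attribute(
    MAX_SHARED_MEM_PER_BLOCK_OPTIN, -1)
print(f"Opt-in shared memory per block: {max_shared_mem} bytes")
```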

csrc/quantization.cpp

@@ -0,0 +1,15 @@
#include <torch/extension.h>
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
}


@@ -0,0 +1,87 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
namespace vllm {
namespace awq {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
#endif
}
} // namespace awq
} // namespace vllm
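The magic constants above rely on a standard fp16 trick: OR-ing a 4-bit value n into the mantissa of 0x6400 (the bit pattern for 1024.0) yields the half-precision value 1024 + n, so a single subtraction recovers n. A small NumPy sketch verifying that property; it is purely illustrative and unrelated to the CUDA build:

```python
import numpy as np

MAGIC = 0x6400  # fp16 bit pattern for 1024.0

for n in range(16):
    bits = np.uint16(MAGIC | n)
    value = bits.view(np.float16)
    # ORing the nibble into the mantissa yields exactly 1024 + n ...
    assert float(value) == 1024.0 + n
    # ... so subtracting the magic number recovers the original int4 value.
    assert float(value) - 1024.0 == n
print("int4 -> fp16 magic-number dequantization verified for 0..15")
```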


@@ -0,0 +1,491 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
#include <cuda_fp16.h>
namespace vllm {
namespace awq {
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (128 + 8)];
__shared__ half scaling_factors_shared[128];
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[32];
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 128;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 2
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ ((int)threadIdx.x) % (128 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (128)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
#endif
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
#endif
}
} // namespace awq
} // namespace vllm
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size");
if (num_out_channels % 128 == 0)
{
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
{
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
}
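For orientation, the expected tensor shapes are documented in the comments above the entry point (in_feats: M x IC fp16; kernel: IC x OC/8 int32; scaling_factors: IC/G x OC fp16; zeros: IC/G x OC/8 int32). A hedged sketch of invoking the op from Python, assuming the extension is exposed as vllm.quantization_ops; the module name, sizes, and group size G=128 are assumptions for illustration:

```python
import torch
from vllm import quantization_ops  # assumed extension module name

M, IC, OC, G = 8, 4096, 4096, 128  # batch, in/out channels, AWQ group size
split_k_iters = 8

in_feats = torch.randn(M, IC, dtype=torch.float16, device="cuda")
# Packed int4 weights: 8 nibbles per int32, hence OC // 8 columns.
kernel = torch.randint(0, 2**31 - 1, (IC, OC // 8),
                       dtype=torch.int32, device="cuda")
scaling_factors = torch.randn(IC // G, OC, dtype=torch.float16, device="cuda")
zeros = torch.randint(0, 2**31 - 1, (IC // G, OC // 8),
                      dtype=torch.int32, device="cuda")

out = quantization_ops.awq_gemm(in_feats, kernel, scaling_factors, zeros,
                                split_k_iters)
print(out.shape)  # expected (M, OC), after summing over the split-k dimension
```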


@@ -3,31 +3,15 @@
 Installation
 ============

-vLLM is a Python library that also contains some C++ and CUDA code.
-This additional code requires compilation on the user's machine.
+vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.

 Requirements
 ------------

 * OS: Linux
-* Python: 3.8 or higher
-* CUDA: 11.0 -- 11.8
+* Python: 3.8 -- 3.11
 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)

-.. note::
-    As of now, vLLM does not support CUDA 12.
-    If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
-
-.. tip::
-    If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
-
-    .. code-block:: console
-
-        $ # Pull the Docker image with CUDA 11.8.
-        $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
-
-    Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
-
 Install with pip
 ----------------
@@ -40,7 +24,7 @@ You can install vLLM using pip:
     $ conda activate myenv

     $ # Install vLLM.
-    $ pip install vllm  # This may take 5-10 minutes.
+    $ pip install vllm

 .. _build_from_source:
@@ -55,3 +39,12 @@ You can also build and install vLLM from source:
     $ git clone https://github.com/vllm-project/vllm.git
     $ cd vllm
     $ pip install -e .  # This may take 5-10 minutes.
+
+.. tip::
+    If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
+
+    .. code-block:: console
+
+        $ # Pull the Docker image with CUDA 11.8.
+        $ # Use `--ipc=host` to make sure the shared memory is large enough.
+        $ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3


@@ -128,4 +128,4 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
         prompt="San Francisco is a")
     print("Completion result:", completion)

-For a more detailed client example, refer to `examples/openai_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_client.py>`_.
+For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
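The renamed example follows the same pattern this page already shows: point the OpenAI client at the vLLM server and issue a completion. A minimal sketch using the pre-1.0 openai Python client; the local server address and the facebook/opt-125m model are assumptions, not taken from this diff:

```python
import openai  # openai < 1.0 style API, as used in 2023-era examples

openai.api_key = "EMPTY"                      # vLLM's server ignores the key
openai.api_base = "http://localhost:8000/v1"  # assumed local server address

completion = openai.Completion.create(
    model="facebook/opt-125m",
    prompt="San Francisco is a",
    max_tokens=32,
)
print("Completion result:", completion)
```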


@@ -43,6 +43,7 @@ vLLM is flexible and easy to use with:
 For more information, check out the following:

 * `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
+* `vLLM paper <https://arxiv.org/abs/2309.06180>`_ (SOSP 2023)
 * `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
@@ -63,6 +64,7 @@ Documentation
    serving/distributed_serving
    serving/run_on_sky
+   serving/deploying_with_triton

 .. toctree::
    :maxdepth: 1


@@ -44,6 +44,9 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`LlamaForCausalLM` * - :code:`LlamaForCausalLM`
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc. - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
* - :code:`MistralForCausalLM`
- Mistral, Mistral-Instruct
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
* - :code:`MPTForCausalLM` * - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
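With the Mistral entry added above, the new checkpoints load through the same offline API as the other decoder-only models; a minimal sketch (the weights are pulled from the Hugging Face Hub on first use):

.. code-block:: python

    from vllm import LLM, SamplingParams

    # mistralai/Mistral-7B-v0.1 is one of the example checkpoints listed above.
    llm = LLM(model="mistralai/Mistral-7B-v0.1")
    params = SamplingParams(temperature=0.8, max_tokens=64)
    outputs = llm.generate(["The theory of relativity states that"], params)
    print(outputs[0].outputs[0].text)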

View File

@@ -0,0 +1,6 @@
.. _deploying_with_triton:
Deploying with NVIDIA Triton
============================
The `Triton Inference Server <https://github.com/triton-inference-server>`_ hosts a tutorial demonstrating how to quickly deploy a simple `facebook/opt-125m <https://huggingface.co/facebook/opt-125m>`_ model using vLLM. Please see `Deploying a vLLM model in Triton <https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton>`_ for more details.

View File

@@ -11,3 +11,4 @@ types-setuptools
# testing # testing
pytest pytest
pytest-forked pytest-forked
pytest-asyncio

View File

@@ -1,11 +1,13 @@
ninja # For faster builds. ninja # For faster builds.
psutil psutil
ray >= 2.5.1 ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy numpy
torch >= 2.0.0 torch >= 2.0.0
transformers >= 4.33.1 # Required for Code Llama. transformers >= 4.33.1 # Required for Code Llama.
xformers >= 0.0.21 xformers >= 0.0.22
fastapi fastapi
uvicorn uvicorn[standard]
pydantic < 2 # Required for OpenAI server. pydantic < 2 # Required for OpenAI server.

setup.py
View File

@@ -3,6 +3,7 @@ import os
import re import re
import subprocess import subprocess
from typing import List, Set from typing import List, Set
import warnings
from packaging.version import parse, Version from packaging.version import parse, Version
import setuptools import setuptools
@@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
ROOT_DIR = os.path.dirname(__file__) ROOT_DIR = os.path.dirname(__file__)
# Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
# Compiler flags. # Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"] CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3? # TODO(woosuk): Should we use -O3?
@@ -22,7 +26,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
if CUDA_HOME is None: if CUDA_HOME is None:
raise RuntimeError( raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available to build the package.") "Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_nvcc_cuda_version(cuda_dir: str) -> Version: def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -38,47 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
return nvcc_cuda_version return nvcc_cuda_version
# Collect the compute capabilities of all available GPUs. def get_torch_arch_list() -> Set[str]:
device_count = torch.cuda.device_count() # TORCH_CUDA_ARCH_LIST can have one or more architectures,
compute_capabilities: Set[int] = set() # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
for i in range(device_count): # compiler to additionally include PTX code that can be runtime-compiled
major, minor = torch.cuda.get_device_capability(i) # and executed on the 8.6 or newer architectures. While the PTX code will
if major < 7: # not give the best performance on the newer architectures, it provides
raise RuntimeError( # forward compatibility.
"GPUs with compute capability less than 7.0 are not supported.") valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
compute_capabilities.add(major * 10 + minor) arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if arch_list is None:
return set()
# Entries in the list are separated by ';' or spaces.
arch_list = arch_list.replace(" ", ";").split(";")
for arch in arch_list:
if arch not in valid_arch_strs:
raise ValueError(
f"Unsupported CUDA arch ({arch}). "
f"Valid CUDA arch strings are: {valid_arch_strs}.")
return set(arch_list)
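For illustration, a small standalone sketch of how a TORCH_CUDA_ARCH_LIST value is interpreted by the logic above (the constant mirrors SUPPORTED_ARCHS; this snippet is not part of the build script itself):

import os

SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]

def parse_arch_list(value: str) -> set:
    # Mirror get_torch_arch_list(): split on ';' or spaces and validate.
    valid = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
    archs = value.replace(" ", ";").split(";")
    for arch in archs:
        if arch not in valid:
            raise ValueError(f"Unsupported CUDA arch ({arch}).")
    return set(archs)

# e.g. build only for 8.0 GPUs and add forward-compatible PTX for 8.6:
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6+PTX"
print(sorted(parse_arch_list(os.environ["TORCH_CUDA_ARCH_LIST"])))  # ['8.0', '8.6+PTX']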
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 7:
raise RuntimeError(
"GPUs with compute capability below 7.0 are not supported.")
compute_capabilities.add(f"{major}.{minor}")
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities = set(SUPPORTED_ARCHS)
if nvcc_cuda_version < Version("11.1"):
compute_capabilities.remove("8.6")
if nvcc_cuda_version < Version("11.8"):
compute_capabilities.remove("8.9")
compute_capabilities.remove("9.0")
# Validate the NVCC CUDA version. # Validate the NVCC CUDA version.
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if nvcc_cuda_version < Version("11.0"): if nvcc_cuda_version < Version("11.0"):
raise RuntimeError("CUDA 11.0 or higher is required to build the package.") raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): if nvcc_cuda_version < Version("11.1"):
raise RuntimeError( if any(cc.startswith("8.6") for cc in compute_capabilities):
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.") raise RuntimeError(
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"): "CUDA 11.1 or higher is required for compute capability 8.6.")
# CUDA 11.8 is required to generate the code targeting compute capability 8.9. if nvcc_cuda_version < Version("11.8"):
# However, GPUs with compute capability 8.9 can also run the code generated by if any(cc.startswith("8.9") for cc in compute_capabilities):
# the previous versions of CUDA 11 and targeting compute capability 8.0. # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 # However, GPUs with compute capability 8.9 can also run the code generated by
# instead of 8.9. # the previous versions of CUDA 11 and targeting compute capability 8.0.
compute_capabilities.remove(89) # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
compute_capabilities.add(80) # instead of 8.9.
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): warnings.warn(
raise RuntimeError( "CUDA 11.8 or higher is required for compute capability 8.9. "
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") "Targeting compute capability 8.0 instead.")
compute_capabilities = set(cc for cc in compute_capabilities
# If no GPU is available, add all supported compute capabilities. if not cc.startswith("8.9"))
if not compute_capabilities: compute_capabilities.add("8.0+PTX")
compute_capabilities = {70, 75, 80} if any(cc.startswith("9.0") for cc in compute_capabilities):
if nvcc_cuda_version >= Version("11.1"): raise RuntimeError(
compute_capabilities.add(86) "CUDA 11.8 or higher is required for compute capability 9.0.")
if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(89)
compute_capabilities.add(90)
# Add target compute capabilities to NVCC flags. # Add target compute capabilities to NVCC flags.
for capability in compute_capabilities: for capability in compute_capabilities:
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"] num = capability[0] + capability[2]
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
# Use NVCC threads to parallelize the build. # Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"): if nvcc_cuda_version >= Version("11.2"):
@@ -91,7 +130,10 @@ ext_modules = []
cache_extension = CUDAExtension( cache_extension = CUDAExtension(
name="vllm.cache_ops", name="vllm.cache_ops",
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"], sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(cache_extension) ext_modules.append(cache_extension)
@@ -99,7 +141,10 @@ ext_modules.append(cache_extension)
attention_extension = CUDAExtension( attention_extension = CUDAExtension(
name="vllm.attention_ops", name="vllm.attention_ops",
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"], sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(attention_extension) ext_modules.append(attention_extension)
@@ -107,7 +152,10 @@ ext_modules.append(attention_extension)
positional_encoding_extension = CUDAExtension( positional_encoding_extension = CUDAExtension(
name="vllm.pos_encoding_ops", name="vllm.pos_encoding_ops",
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"], sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(positional_encoding_extension) ext_modules.append(positional_encoding_extension)
@@ -115,7 +163,10 @@ ext_modules.append(positional_encoding_extension)
layernorm_extension = CUDAExtension( layernorm_extension = CUDAExtension(
name="vllm.layernorm_ops", name="vllm.layernorm_ops",
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"], sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(layernorm_extension) ext_modules.append(layernorm_extension)
@@ -123,10 +174,38 @@ ext_modules.append(layernorm_extension)
activation_extension = CUDAExtension( activation_extension = CUDAExtension(
name="vllm.activation_ops", name="vllm.activation_ops",
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"], sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(activation_extension) ext_modules.append(activation_extension)
# Quantization kernels.
quantization_extension = CUDAExtension(
name="vllm.quantization_ops",
sources=[
"csrc/quantization.cpp",
"csrc/quantization/awq/gemm_kernels.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(quantization_extension)
# Misc. CUDA utils.
cuda_utils_extension = CUDAExtension(
name="vllm.cuda_utils",
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(cuda_utils_extension)
def get_path(*filepath) -> str: def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath) return os.path.join(ROOT_DIR, *filepath)
@@ -138,8 +217,8 @@ def find_version(filepath: str):
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
""" """
with open(filepath) as fp: with open(filepath) as fp:
version_match = re.search( version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M) fp.read(), re.M)
if version_match: if version_match:
return version_match.group(1) return version_match.group(1)
raise RuntimeError("Unable to find version string.") raise RuntimeError("Unable to find version string.")
@@ -162,7 +241,8 @@ setuptools.setup(
version=find_version(get_path("vllm", "__init__.py")), version=find_version(get_path("vllm", "__init__.py")),
author="vLLM Team", author="vLLM Team",
license="Apache 2.0", license="Apache 2.0",
description="A high-throughput and memory-efficient inference and serving engine for LLMs", description=("A high-throughput and memory-efficient inference and "
"serving engine for LLMs"),
long_description=read_readme(), long_description=read_readme(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/vllm-project/vllm", url="https://github.com/vllm-project/vllm",
@@ -174,11 +254,12 @@ setuptools.setup(
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
], ],
packages=setuptools.find_packages( packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")), "examples", "tests")),
python_requires=">=3.8", python_requires=">=3.8",
install_requires=get_requirements(), install_requires=get_requirements(),
ext_modules=ext_modules, ext_modules=ext_modules,

View File

@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass
import pytest
from vllm.engine.async_llm_engine import AsyncLLMEngine
@dataclass
class RequestOutput:
request_id: int
finished: bool = False
class MockEngine:
def __init__(self):
self.step_calls = 0
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
async def step_async(self):
self.step_calls += 1
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []
def generate(self, request_id):
self.request_id = request_id
def stop_generating(self):
self.request_id = None
def add_request(self, **kwargs):
self.add_request_calls += 1
return
def abort_request(self, request_id):
self.abort_request_calls += 1
return
class MockAsyncLLMEngine(AsyncLLMEngine):
def _init_engine(self, *args, **kwargs):
return MockEngine()
@pytest.mark.asyncio
async def test_new_requests_event():
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
await engine.add_request("1", "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1
await engine.add_request("2", "", None)
engine.engine.generate("2")
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls == 2
await asyncio.sleep(0)
assert engine.engine.step_calls == 3
engine.engine.stop_generating()
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await engine.add_request("3", "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5

View File

@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
class DummyEvent:
def __init__(self):
self._flag = False
def set(self):
self._flag = True
def clear(self):
self._flag = False
def test_request_tracker(): def test_request_tracker():
tracker = RequestTracker() tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()
stream_1 = tracker.add_request("1") stream_1 = tracker.add_request("1")
assert tracker.new_requests_event._flag
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert len(new) == 1 assert len(new) == 1
assert new[0]["request_id"] == "1" assert new[0]["request_id"] == "1"
assert not finished assert not finished
@@ -15,7 +30,9 @@ def test_request_tracker():
stream_2 = tracker.add_request("2") stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3") stream_3 = tracker.add_request("3")
assert tracker.new_requests_event._flag
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert len(new) == 2 assert len(new) == 2
assert new[0]["request_id"] == "2" assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3" assert new[1]["request_id"] == "3"
@@ -26,6 +43,7 @@ def test_request_tracker():
# request_ids must be unique # request_ids must be unique
with pytest.raises(KeyError): with pytest.raises(KeyError):
tracker.add_request("1") tracker.add_request("1")
assert not tracker.new_requests_event._flag
tracker.abort_request("1") tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
@@ -36,6 +54,7 @@ def test_request_tracker():
stream_4 = tracker.add_request("4") stream_4 = tracker.add_request("4")
tracker.abort_request("4") tracker.abort_request("4")
assert tracker.new_requests_event._flag
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1 assert len(finished) == 1
assert "4" in finished assert "4" in finished
@@ -43,9 +62,11 @@ def test_request_tracker():
assert stream_4.finished assert stream_4.finished
stream_5 = tracker.add_request("5") stream_5 = tracker.add_request("5")
assert tracker.new_requests_event._flag
tracker.process_request_output( tracker.process_request_output(
RequestOutput("2", "output", [], [], finished=True)) RequestOutput("2", "output", [], [], finished=True))
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert len(finished) == 1 assert len(finished) == 1
assert "2" in finished assert "2" in finished
assert len(new) == 1 assert len(new) == 1

View File

@@ -0,0 +1,62 @@
import pytest
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import detokenize_incrementally
TRUTH = [
"Hello here, this is a simple test",
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
"我很感谢你的热情"
]
TOKENIZERS = [
"facebook/opt-125m",
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"mosaicml/mpt-7b",
"tiiuae/falcon-7b",
"meta-llama/Llama-2-7b-hf",
"codellama/CodeLlama-7b-hf",
]
def _run_incremental_decode(tokenizer, all_input_ids,
skip_special_tokens: bool):
decoded_text = ""
offset = 0
token_offset = 0
prev_tokens = None
for i in range(len(all_input_ids)):
new_tokens, text, offset, token_offset = detokenize_incrementally(
tokenizer,
all_input_ids[:i + 1],
prev_tokens,
offset,
token_offset,
skip_special_tokens=skip_special_tokens)
decoded_text += text
if prev_tokens is None:
prev_tokens = new_tokens
else:
prev_tokens += new_tokens
return decoded_text
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
if skip_special_tokens:
all_input_ids = ([tokenizer.bos_token_id]
if tokenizer.bos_token_id is not None else
[]) + all_input_ids + [tokenizer.eos_token_id]
decoded_text = _run_incremental_decode(
tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
assert decoded_text == truth

View File

@@ -7,8 +7,12 @@ from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm import attention_ops from vllm import attention_ops
from vllm.utils import get_max_shared_memory_bytes
MAX_SEQ_LEN = 8192 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 128 # Arbitrary values for testing NUM_BLOCKS = 128 # Arbitrary values for testing
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
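For intuition, the new formula converts the per-block shared-memory budget into a sequence-length cap; a worked example under an assumed 96 KB limit (the real value is queried from the device via get_max_shared_memory_bytes()):

# Hypothetical device limit; actual GPUs report different values.
shared_mem_bytes = 96 * 1024
FLOAT32_BYTES = 4  # torch.finfo(torch.float).bits // 8
MAX_SEQ_LEN = shared_mem_bytes // FLOAT32_BYTES - 512  # 512 kept as a buffer
print(MAX_SEQ_LEN)  # 24064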
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
device="cuda") device="cuda")
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens) max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
@@ -242,7 +247,11 @@ def test_multi_query_kv_attention(
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096)
seq_lens = random.sample(range(1, max_len), num_seqs)
num_tokens = sum(seq_lens) num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))

View File

@@ -133,9 +133,10 @@ def test_rotary_embedding(
device="cuda") device="cuda")
# Create the rotary embedding. # Create the rotary embedding.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) inv_freq = 1.0 / (base**(
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position).float() t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float()) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1) cos_sin_cache = torch.cat((cos, sin), dim=-1)

View File

@@ -0,0 +1,184 @@
import pytest
import random
from typing import Tuple
from unittest.mock import patch
import torch
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker
class MockLogitsSampler(Sampler):
def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size)
self.fake_logits = fake_logits
def forward(self, *args, **kwargs):
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x):
with patch("vllm.model_executor.layers.sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device="cuda",
dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
worker = Worker(None, None, None)
worker.block_size = 16
return input_tensor, fake_logits, sampler, worker
RANDOM_SEEDS = list(range(128))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0, ),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == expected[i].item()
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == i
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
expected_tokens = []
for i in range(batch_size):
n = 1
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
elif sampling_type == 1:
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
top_p=min(random.random() + 0.1, 1),
top_k=random.randint(0, 10) or -1,
n=n,
presence_penalty=random.randint(0, 1),
)
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected_tokens.append(i + idx)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output:
assert nth_output.output_token in expected_tokens

View File

@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
__version__ = "0.1.6" __version__ = "0.2.0"
__all__ = [ __all__ = [
"LLM", "LLM",

View File

@@ -38,6 +38,13 @@ class ModelConfig:
will use FP16 precision for FP32 and FP16 models, and BF16 precision will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models. for BF16 models.
seed: Random seed for reproducibility. seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
""" """
def __init__( def __init__(
@@ -50,6 +57,9 @@ class ModelConfig:
load_format: str, load_format: str,
dtype: str, dtype: str,
seed: int, seed: int,
revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
@@ -58,11 +68,16 @@ class ModelConfig:
self.download_dir = download_dir self.download_dir = download_dir
self.load_format = load_format self.load_format = load_format
self.seed = seed self.seed = seed
self.revision = revision
self.quantization = quantization
self.hf_config = get_config(model, trust_remote_code) self.hf_config = get_config(model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config,
max_model_len)
self._verify_load_format() self._verify_load_format()
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self._verify_quantization()
def _verify_load_format(self) -> None: def _verify_load_format(self) -> None:
load_format = self.load_format.lower() load_format = self.load_format.lower()
@@ -82,6 +97,17 @@ class ModelConfig:
"either 'auto' or 'slow'.") "either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq"]
if self.quantization is None:
return
quantization = self.quantization.lower()
if quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization: {self.quantization}. Must be one of "
f"{supported_quantization}.")
self.quantization = quantization
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
@@ -109,22 +135,28 @@ class ModelConfig:
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU worker."""
# For GPTBigCode & Falcon: # For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the # Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of # multi_query flag is ignored and we use n_head_kv for the number of
# KV heads. # KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = ( new_decoder_arch_falcon = (
self.hf_config.model_type == "falcon" self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)) and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config, if not new_decoder_arch_falcon and getattr(self.hf_config,
"multi_query", False): "multi_query", False):
# Multi-query attention, only one KV head. # Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1 return 1
# For Falcon: # For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None: if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv // return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size) parallel_config.tensor_parallel_size)
if getattr(self.hf_config, "num_kv_heads", None) is not None:
return (self.hf_config.num_kv_heads //
parallel_config.tensor_parallel_size)
# For LLaMA-2: # For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None: if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads // return (self.hf_config.num_key_value_heads //
@@ -132,26 +164,6 @@ class ModelConfig:
total_num_attention_heads = self.hf_config.num_attention_heads total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size return total_num_attention_heads // parallel_config.tensor_parallel_size
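As an illustration of the updated KV-head selection, a small sketch (not the real ModelConfig; the Falcon head counts used below are the commonly published ones and should be treated as assumptions):

def num_kv_heads(multi_query: bool, n_head_kv, num_attention_heads: int, tp: int) -> int:
    if multi_query and n_head_kv is None:
        return 1                        # multi-query attention: one shared KV head
    if n_head_kv is not None:
        return n_head_kv // tp          # grouped KV heads, split across TP workers
    return num_attention_heads // tp    # default: one KV head per attention head

print(num_kv_heads(True, None, 71, 1))  # 1  (Falcon-7B-style multi-query)
print(num_kv_heads(False, 8, 128, 2))   # 4  (Falcon-40B-style, tensor parallel size 2)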
def get_max_model_len(self) -> int:
max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(self.hf_config, key, None)
if max_len_key is not None:
max_model_len = min(max_model_len, max_len_key)
return max_model_len
def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size return total_num_hidden_layers // parallel_config.pipeline_parallel_size
@@ -172,10 +184,12 @@ class CacheConfig:
block_size: int, block_size: int,
gpu_memory_utilization: float, gpu_memory_utilization: float,
swap_space: int, swap_space: int,
sliding_window: Optional[int] = None,
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB self.swap_space_bytes = swap_space * _GB
self.sliding_window = sliding_window
self._verify_args() self._verify_args()
# Will be set after profiling. # Will be set after profiling.
@@ -251,11 +265,36 @@ class SchedulerConfig:
and generated text). and generated text).
""" """
def __init__(self, max_num_batched_tokens: int, max_num_seqs: int, def __init__(
max_model_len: int) -> None: self,
self.max_num_batched_tokens = max_num_batched_tokens max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
# If max_model_len is too short, use 2048 as the default value for
# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self._verify_args()
def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs}).")
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -311,3 +350,56 @@ def _get_and_verify_dtype(
f"of at least 8.0. Your {gpu_name} GPU has compute capability " f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.") f"{compute_capability[0]}.{compute_capability[1]}.")
return torch_dtype return torch_dtype
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(hf_config, key, None)
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
return max_model_len
default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f"{possible_keys}. Assuming the model's maximum length is "
f"{default_max_len}.")
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
derived_max_model_len *= scaling_factor
if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
raise ValueError(
f"User-specified max_model_len ({max_model_len}) is greater than "
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
" in model's config.json). This may lead to incorrect model "
"outputs or CUDA errors. Make sure the value is correct and "
"within the model context size.")
return int(max_model_len)
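A worked example of the derivation above, assuming a hypothetical LLaMA-style config with max_position_embeddings=2048 and a linear RoPE scaling factor of 4.0 (the LongChat-style setup):

# Values below are illustrative, not read from a real config.json.
derived_max_model_len = 2048                      # from max_position_embeddings
rope_scaling = {"type": "linear", "factor": 4.0}
derived_max_model_len *= rope_scaling["factor"]
print(int(derived_max_model_len))                 # 8192 tokens allowed per request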

View File

@@ -63,10 +63,18 @@ class BlockSpaceManager:
num_gpu_blocks: int, num_gpu_blocks: int,
num_cpu_blocks: int, num_cpu_blocks: int,
watermark: float = 0.01, watermark: float = 0.01,
sliding_window: Optional[int] = None,
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks self.num_total_cpu_blocks = num_cpu_blocks
self.block_sliding_window = None
if sliding_window is not None:
assert sliding_window % block_size == 0, (sliding_window,
block_size)
self.block_sliding_window = sliding_window // block_size
self.watermark = watermark self.watermark = watermark
assert watermark >= 0.0 assert watermark >= 0.0
@@ -83,6 +91,9 @@ class BlockSpaceManager:
# the same prompt. This may not be true for preempted sequences. # the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs()[0] seq = seq_group.get_seqs()[0]
num_required_blocks = len(seq.logical_token_blocks) num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction. # Use watermark to avoid frequent cache eviction.
return (num_free_gpu_blocks - num_required_blocks >= return (num_free_gpu_blocks - num_required_blocks >=
@@ -95,8 +106,12 @@ class BlockSpaceManager:
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = [] block_table: BlockTable = []
for _ in range(len(seq.logical_token_blocks)): for logical_idx in range(len(seq.logical_token_blocks)):
block = self.gpu_allocator.allocate() if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks. # Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs() block.ref_count = seq_group.num_seqs()
block_table.append(block) block_table.append(block)
@@ -118,11 +133,17 @@ class BlockSpaceManager:
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
if len(block_table) < len(logical_blocks): if len(block_table) < len(logical_blocks):
# The sequence has a new logical block. if (self.block_sliding_window
# Allocate a new physical block. and len(block_table) >= self.block_sliding_window):
block = self.gpu_allocator.allocate() # re-use a block
block_table.append(block) block_table.append(block_table[len(block_table) %
return None self.block_sliding_window])
else:
# The sequence has a new logical block.
# Allocate a new physical block.
block = self.gpu_allocator.allocate()
block_table.append(block)
return None
# We want to append the token to the last physical block. # We want to append the token to the last physical block.
last_block = block_table[-1] last_block = block_table[-1]
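A worked example of the block reuse above, assuming Mistral-style settings (sliding_window=4096 tokens, block_size=16); logical blocks past the window wrap around and overwrite the oldest physical slots:

sliding_window = 4096
block_size = 16
block_sliding_window = sliding_window // block_size  # 256 physical blocks per sequence

def physical_slot(logical_idx: int) -> int:
    # Mirrors the indexing above: reuse a slot once the window is exceeded.
    if logical_idx >= block_sliding_window:
        return logical_idx % block_sliding_window
    return logical_idx

print(physical_slot(10))   # 10 (still inside the window)
print(physical_slot(300))  # 44 (reuses slot 300 % 256)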
@@ -154,9 +175,7 @@ class BlockSpaceManager:
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs():
if seq.is_finished(): if seq.is_finished():
continue continue
block_table = self.block_tables[seq.seq_id] blocks.update(self.block_tables[seq.seq_id])
for block in block_table:
blocks.add(block)
return list(blocks) return list(blocks)
def can_swap_in(self, seq_group: SequenceGroup) -> bool: def can_swap_in(self, seq_group: SequenceGroup) -> bool:
@@ -224,7 +243,7 @@ class BlockSpaceManager:
return block_number_mapping return block_number_mapping
def _free_block_table(self, block_table: BlockTable) -> None: def _free_block_table(self, block_table: BlockTable) -> None:
for block in block_table: for block in set(block_table):
if block.device == Device.GPU: if block.device == Device.GPU:
self.gpu_allocator.free(block) self.gpu_allocator.free(block)
else: else:

View File

@@ -73,7 +73,7 @@ class Scheduler:
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks, num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
) sliding_window=self.cache_config.sliding_window)
# TODO(zhuohan): Use deque instead of list for better performance. # TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
@@ -175,7 +175,7 @@ class Scheduler:
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
scheduled.append(seq_group) scheduled.append(seq_group)
if scheduled: if scheduled or ignored_seq_groups:
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled, scheduled_seq_groups=scheduled,
prompt_run=True, prompt_run=True,

View File

@@ -18,20 +18,22 @@ class EngineArgs:
load_format: str = 'auto' load_format: str = 'auto'
dtype: str = 'auto' dtype: str = 'auto'
seed: int = 0 seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False worker_use_ray: bool = False
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1 tensor_parallel_size: int = 1
block_size: int = 16 block_size: int = 16
swap_space: int = 4 # GiB swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: int = 2560 max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256 max_num_seqs: int = 256
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None
quantization: Optional[str] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
self.tokenizer = self.model self.tokenizer = self.model
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
@@ -48,6 +50,13 @@ class EngineArgs:
type=str, type=str,
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use') help='name or path of the huggingface tokenizer to use')
parser.add_argument(
'--revision',
type=str,
default=None,
help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-mode', parser.add_argument('--tokenizer-mode',
type=str, type=str,
default=EngineArgs.tokenizer_mode, default=EngineArgs.tokenizer_mode,
@@ -79,16 +88,22 @@ class EngineArgs:
'a numpy cache to speed up the loading. ' 'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, ' '"dummy" will initialize the weights with random values, '
'which is mainly for profiling.') 'which is mainly for profiling.')
# TODO(woosuk): Support FP32.
parser.add_argument( parser.add_argument(
'--dtype', '--dtype',
type=str, type=str,
default=EngineArgs.dtype, default=EngineArgs.dtype,
choices=['auto', 'half', 'bfloat16', 'float'], choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
help='data type for model weights and activations. ' help='data type for model weights and activations. '
'The "auto" option will use FP16 precision ' 'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision ' 'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.') 'for BF16 models.')
parser.add_argument('--max-model-len',
type=int,
default=None,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
# Parallel arguments # Parallel arguments
parser.add_argument('--worker-use-ray', parser.add_argument('--worker-use-ray',
action='store_true', action='store_true',
@@ -136,6 +151,13 @@ class EngineArgs:
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
help='disable logging statistics') help='disable logging statistics')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
default=None,
help='Method used to quantize the weights')
return parser return parser
@classmethod @classmethod
@@ -149,20 +171,20 @@ class EngineArgs:
def create_engine_configs( def create_engine_configs(
self, self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
model_config = ModelConfig(self.model, self.tokenizer, model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code, self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format, self.download_dir, self.load_format,
self.dtype, self.seed) self.dtype, self.seed, self.revision,
cache_config = CacheConfig(self.block_size, self.max_model_len, self.quantization)
self.gpu_memory_utilization, cache_config = CacheConfig(
self.swap_space) self.block_size, self.gpu_memory_utilization, self.swap_space,
getattr(model_config.hf_config, 'sliding_window', None))
parallel_config = ParallelConfig(self.pipeline_parallel_size, parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size, self.tensor_parallel_size,
self.worker_use_ray) self.worker_use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs, self.max_num_seqs,
model_config.get_max_model_len()) model_config.max_model_len)
return model_config, cache_config, parallel_config, scheduler_config return model_config, cache_config, parallel_config, scheduler_config
@@ -171,6 +193,7 @@ class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine.""" """Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False engine_use_ray: bool = False
disable_log_requests: bool = False disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
@@ -183,4 +206,10 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='disable logging requests') help='disable logging requests')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='max number of prompt characters or prompt '
'ID numbers being printed in log. '
'Default: unlimited.')
return parser return parser
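Putting the new arguments together, a typical offline invocation might look like the sketch below (it assumes the LLM entrypoint forwards revision, max_model_len and quantization through to EngineArgs; the AWQ checkpoint name is a placeholder, not a real model):

from vllm import LLM, SamplingParams

llm = LLM(
    model="some-org/llama-2-7b-awq",  # hypothetical AWQ-quantized checkpoint
    quantization="awq",               # new --quantization / -q option
    revision="main",                  # new --revision option
    max_model_len=4096,               # new --max-model-len option
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8))[0].outputs[0].text)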

View File

@@ -1,7 +1,8 @@
import asyncio import asyncio
import time import time
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union)
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
@@ -78,14 +79,24 @@ class RequestTracker:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue() self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream, self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue() dict]] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item): def __contains__(self, item):
return item in self._request_streams return item in self._request_streams
def propagate_exception(self, exc: Exception) -> None: def init_event(self):
"""Propagate an exception to all request streams.""" self.new_requests_event = asyncio.Event()
for stream in self._request_streams.values():
stream.put(exc) def propagate_exception(self,
exc: Exception,
request_id: Optional[str] = None) -> None:
"""Propagate an exception to request streams
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
else:
for stream in self._request_streams.values():
stream.put(exc)
def process_request_output(self, def process_request_output(self,
request_output: RequestOutput, request_output: RequestOutput,
@@ -112,6 +123,9 @@ class RequestTracker:
"request_id": request_id, "request_id": request_id,
**engine_add_request_kwargs **engine_add_request_kwargs
})) }))
self.new_requests_event.set()
return stream return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None: def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -148,8 +162,13 @@ class RequestTracker:
self._request_streams[stream.request_id] = stream self._request_streams[stream.request_id] = stream
new_requests.append(new_request) new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests return new_requests, finished_requests
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
class _AsyncLLMEngine(LLMEngine): class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods.""" """Extension of LLMEngine to add async methods."""
@@ -164,10 +183,9 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
(seq_group_metadata_list, scheduler_outputs, seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
early_return) = self._schedule() if scheduler_outputs.is_empty():
if early_return is not None: return ignored
return early_return
# Execute the model. # Execute the model.
output = await self._run_workers_async( output = await self._run_workers_async(
@@ -178,7 +196,7 @@ class _AsyncLLMEngine(LLMEngine):
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) )
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs) + ignored
async def _run_workers_async( async def _run_workers_async(
self, self,
@@ -242,16 +260,22 @@ class AsyncLLMEngine:
engine_use_ray: bool, engine_use_ray: bool,
*args, *args,
log_requests: bool = True, log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True, start_engine_loop: bool = True,
**kwargs) -> None: **kwargs) -> None:
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray self.engine_use_ray = engine_use_ray
self.log_requests = log_requests self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs) self.engine = self._init_engine(*args, **kwargs)
self.request_tracker: RequestTracker = RequestTracker()
self.background_loop = None self.background_loop = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@@ -262,11 +286,14 @@ class AsyncLLMEngine:
"""Start the background loop.""" """Start the background loop."""
if self.is_running: if self.is_running:
raise RuntimeError("Background loop is already running.") raise RuntimeError("Background loop is already running.")
self.background_loop = asyncio.get_event_loop().create_task( self._request_tracker.init_event()
self.run_engine_loop())
self.background_loop.add_done_callback( self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, partial(_raise_exception_on_finish,
request_tracker=self.request_tracker)) request_tracker=self._request_tracker))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args, def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
@@ -278,11 +305,13 @@ class AsyncLLMEngine:
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
return engine_class(*args, **kwargs) return engine_class(*args, **kwargs)
async def engine_step(self): async def engine_step(self) -> bool:
"""Kick the engine to process the waiting requests.""" """Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests, finished_requests = ( new_requests, finished_requests = (
self.request_tracker.get_new_and_finished_requests()) self._request_tracker.get_new_and_finished_requests())
for new_request in new_requests: for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue. # Add the request into the vLLM engine's waiting queue.
@@ -302,9 +331,11 @@ class AsyncLLMEngine:
# Put the outputs into the corresponding streams. # Put the outputs into the corresponding streams.
for request_output in request_outputs: for request_output in request_outputs:
self.request_tracker.process_request_output( self._request_tracker.process_request_output(
request_output, verbose=self.log_requests) request_output, verbose=self.log_requests)
return len(request_outputs) > 0
async def _engine_abort(self, request_ids: Iterable[str]): async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids) await self.engine.abort_request.remote(request_ids)
@@ -312,8 +343,12 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress = False
while True: while True:
await self.engine_step() if not has_requests_in_progress:
await self._request_tracker.wait_for_new_requests()
has_requests_in_progress = await self.engine_step()
await asyncio.sleep(0) await asyncio.sleep(0)
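The intent of the reworked loop is to sleep on an asyncio.Event while no requests are in flight instead of spinning; a stripped-down sketch of that pattern, independent of the vLLM classes:

import asyncio

async def engine_loop(new_requests_event: asyncio.Event, step) -> None:
    # step() is assumed to run one engine iteration and return True
    # while requests are still in progress.
    busy = False
    while True:
        if not busy:
            await new_requests_event.wait()  # woken by add_request() via set()
            new_requests_event.clear()
        busy = await step()
        await asyncio.sleep(0)               # cooperative yield, as in the code above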
async def add_request( async def add_request(
@@ -325,10 +360,18 @@ class AsyncLLMEngine:
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
logger.info(f"Received request {request_id}: " logger.info(f"Received request {request_id}: "
f"prompt: {prompt!r}, " f"prompt: {shortened_prompt!r}, "
f"sampling params: {sampling_params}, " f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.") f"prompt token ids: {shortened_token_ids}.")
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
@@ -340,7 +383,7 @@ class AsyncLLMEngine:
"error that caused the background loop to stop " "error that caused the background loop to stop "
"(AsyncEngineDeadError).") "(AsyncEngineDeadError).")
stream = self.request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
prompt=prompt, prompt=prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
@@ -385,8 +428,9 @@ class AsyncLLMEngine:
async for request_output in stream: async for request_output in stream:
yield request_output yield request_output
except Exception as e: except (Exception, asyncio.CancelledError) as e:
# If there is an exception, abort the request. # If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id) self._abort(request_id)
raise e raise e
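Catching `asyncio.CancelledError` alongside ordinary exceptions means the request is also aborted when the consuming coroutine is cancelled, e.g. on a client disconnect. A compact sketch of the pattern (illustrative names):

```python
import asyncio

# Sketch of the abort-on-cancellation pattern above; `stream` and `abort`
# are stand-ins, not the actual vLLM generate() internals.
async def stream_results(stream, abort):
    try:
        async for request_output in stream:
            yield request_output
    except (Exception, asyncio.CancelledError):
        # Abort whether the failure is an engine error or the consuming
        # coroutine was cancelled (e.g. the HTTP client went away).
        await abort()
        raise

async def demo():
    async def fake_stream():
        for i in range(3):
            yield i
    async def fake_abort():
        print("aborted")
    async for item in stream_results(fake_stream(), fake_abort):
        print(item)

asyncio.run(demo())
```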
@@ -417,8 +461,8 @@ class AsyncLLMEngine:
Args: Args:
request_id: The unique id of the request. request_id: The unique id of the request.
""" """
self.request_tracker.abort_request(request_id, self._request_tracker.abort_request(request_id,
verbose=self.log_requests) verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
@@ -446,5 +490,6 @@ class AsyncLLMEngine:
placement_group, placement_group,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop) start_engine_loop=start_engine_loop)
return engine return engine

View File

@@ -54,8 +54,8 @@ class LLMEngine:
scheduler_config: The configuration related to the request scheduler. scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details. execution. See `torch.distributed.init_process_group` for details.
stage_devices: The list of devices for each stage. Each stage is a list placement_group: Ray placement group for distributed execution.
of (rank, node_resource, device) tuples. Required for distributed execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
""" """
@@ -74,16 +74,21 @@ class LLMEngine:
f"model={model_config.model!r}, " f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, " f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, " f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
f"trust_remote_code={model_config.trust_remote_code}, " f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, " f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, " f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"seed={model_config.seed})") f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
assert self.cache_config.sliding_window == getattr(
self.model_config.hf_config, "sliding_window", None)
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.log_stats = log_stats self.log_stats = log_stats
@@ -92,7 +97,8 @@ class LLMEngine:
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
model_config.tokenizer, model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode, tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code) trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision)
self.seq_counter = Counter() self.seq_counter = Counter()
# Create the parallel GPU workers. # Create the parallel GPU workers.
@@ -153,7 +159,7 @@ class LLMEngine:
placement_group=placement_group, placement_group=placement_group,
placement_group_capture_child_tasks=True), placement_group_capture_child_tasks=True),
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorker).remote() )(RayWorker).remote(self.model_config.trust_remote_code)
self.workers.append(worker) self.workers.append(worker)
# Initialize torch distributed process group for the workers. # Initialize torch distributed process group for the workers.
@@ -291,14 +297,12 @@ class LLMEngine:
def _schedule( def _schedule(
self self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
Optional[List[RequestOutput]]]: List[RequestOutput]]:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if scheduler_outputs.is_empty(): return seq_group_metadata_list, scheduler_outputs, [
return seq_group_metadata_list, scheduler_outputs, [ RequestOutput.from_seq_group(seq_group)
RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups
for seq_group in scheduler_outputs.ignored_seq_groups ]
]
return seq_group_metadata_list, scheduler_outputs, None
def _check_beam_search_early_stopping( def _check_beam_search_early_stopping(
self, self,
@@ -386,7 +390,7 @@ class LLMEngine:
child_seqs.append((parent, parent)) child_seqs.append((parent, parent))
for seq, _ in child_seqs: for seq, _ in child_seqs:
self._decode_sequence(seq) self._decode_sequence(seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params) self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case # Non-beam search case
@@ -542,10 +546,9 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
(seq_group_metadata_list, scheduler_outputs, seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
early_return) = self._schedule() if scheduler_outputs.is_empty():
if early_return is not None: return ignored
return early_return
# Execute the model. # Execute the model.
output = self._run_workers( output = self._run_workers(
@@ -556,7 +559,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) )
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs) + ignored
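`_schedule()` now always returns the ignored sequence groups, and `step()` appends them to the processed outputs instead of early-returning. A control-flow sketch with stand-in callables (not the real classes):

```python
# Control-flow sketch of the revised step(); the three callables are
# hypothetical stand-ins for the scheduler, model execution, and output
# post-processing.
def step(schedule, run_model, process_outputs):
    seq_group_metadata_list, scheduler_outputs, ignored = schedule()
    if scheduler_outputs.is_empty():
        # Nothing was scheduled this iteration; still surface sequence groups
        # that were ignored (e.g. prompts longer than the model length).
        return ignored
    output = run_model(seq_group_metadata_list, scheduler_outputs)
    # Processed outputs and ignored requests are returned together.
    return process_outputs(output, scheduler_outputs) + ignored
```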
def _log_system_stats( def _log_system_stats(
self, self,
@@ -621,17 +624,25 @@ class LLMEngine:
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now self.last_logging_time = now
-    def _decode_sequence(self, seq: Sequence) -> None:
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
-        new_token, new_output_text = detokenize_incrementally(
-            self.tokenizer,
-            seq.output_tokens,
-            seq.get_last_token_id(),
-            skip_special_tokens=True,
-        )
-        if new_token is not None:
-            seq.output_tokens.append(new_token)
-            seq.output_text = new_output_text
+        (new_tokens, new_output_text, prefix_offset,
+         read_offset) = detokenize_incrementally(
+             self.tokenizer,
+             all_input_ids=seq.get_token_ids(),
+             prev_tokens=seq.tokens,
+             prefix_offset=seq.prefix_offset,
+             read_offset=seq.read_offset,
+             skip_special_tokens=sampling_params.skip_special_tokens,
+         )
+        if seq.tokens is None:
+            seq.tokens = new_tokens
+        else:
+            seq.tokens.extend(new_tokens)
+        seq.prefix_offset = prefix_offset
+        seq.read_offset = read_offset
+        seq.output_text += new_output_text
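Detokenization is now incremental and stateful: each sequence carries its decoded tokens plus prefix/read offsets, so only a small window is re-decoded per step and `skip_special_tokens` is honoured per request. A simplified sketch of the offset idea, assuming a Hugging Face tokenizer (not the actual vLLM helper):

```python
from transformers import AutoTokenizer

def detokenize_step(tokenizer, all_token_ids, prefix_offset, read_offset):
    # Decode two overlapping windows: the previously surfaced prefix and that
    # prefix extended by the new tokens; their difference is the new text.
    prefix_text = tokenizer.decode(all_token_ids[prefix_offset:read_offset])
    new_text = tokenizer.decode(all_token_ids[prefix_offset:])
    if new_text.endswith("\ufffd"):
        # The last token does not yet complete a character; emit nothing.
        return "", prefix_offset, read_offset
    return new_text[len(prefix_text):], read_offset, len(all_token_ids)

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("Incremental detokenization keeps offsets per sequence.")
text, p, r = "", 0, 0
for i in range(1, len(ids) + 1):
    piece, p, r = detokenize_step(tok, ids[:i], p, r)
    text += piece
print(text)
```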
def _check_stop(self, seq: Sequence, def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None: sampling_params: SamplingParams) -> None:
@@ -643,6 +654,9 @@ class LLMEngine:
seq.output_text = seq.output_text[:-len(stop_str)] seq.output_text = seq.output_text[:-len(stop_str)]
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
if seq.get_last_token_id() in sampling_params.stop_token_ids:
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if the sequence has reached max_model_len. # Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len: if seq.get_len() > self.scheduler_config.max_model_len:

View File

@@ -2,6 +2,9 @@ import socket
from typing import Optional, Tuple, TYPE_CHECKING from typing import Optional, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
try: try:
import ray import ray
@@ -11,7 +14,11 @@ try:
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.""" lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self) -> None: def __init__(self, init_cached_hf_modules=False) -> None:
if init_cached_hf_modules:
# pylint: disable=import-outside-toplevel
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None self.worker = None
def init_worker(self, worker_init_fn): def init_worker(self, worker_init_fn):
@@ -24,7 +31,10 @@ try:
executor = getattr(self, method) executor = getattr(self, method)
return executor(*args, **kwargs) return executor(*args, **kwargs)
except ImportError: except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with "
"`pip install ray pandas pyarrow`.")
ray = None ray = None
TorchDistributedWorker = None TorchDistributedWorker = None
RayWorker = None # pylint: disable=invalid-name RayWorker = None # pylint: disable=invalid-name
@@ -53,11 +63,10 @@ def initialize_cluster(
the default Ray cluster address. the default Ray cluster address.
Returns: Returns:
A tuple of (`distributed_init_method`, `all_stage_devices`). The A tuple of (`distributed_init_method`, `placement_group`). The
`distributed_init_method` is the address for initializing the `distributed_init_method` is the address for initializing the
distributed backend. `all_stage_devices` includes device IDs for distributed backend. `placement_group` includes the specification
each worker in each pipeline stage. Each device ID is a tuple of of the resources for each distributed worker.
(rank, node resource, device id).
""" """
if parallel_config.worker_use_ray or engine_use_ray: if parallel_config.worker_use_ray or engine_use_ray:
if ray is None: if ray is None:

View File

@@ -2,7 +2,7 @@ import argparse
import json import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn import uvicorn
@@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
ret = {"text": text_outputs} ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await engine.abort(request_id)
if stream: if stream:
background_tasks = BackgroundTasks() return StreamingResponse(stream_results())
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case # Non-streaming case
final_output = None final_output = None
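With the `BackgroundTasks`-based abort removed, disconnects are handled by the streaming generator being cancelled, or by polling `Request.is_disconnected()` in the non-streaming path and aborting explicitly. A minimal FastAPI sketch with a dummy engine stand-in (not the vLLM `AsyncLLMEngine`):

```python
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

class _DummyEngine:
    """Stand-in engine: streams back the words of the prompt."""
    async def _gen(self, prompt):
        for word in prompt.split():
            await asyncio.sleep(0)
            yield word
    def generate(self, prompt, request_id):
        return self._gen(prompt)
    async def abort(self, request_id):
        print(f"aborted {request_id}")

engine = _DummyEngine()
app = FastAPI()

@app.post("/generate")
async def generate(request: Request):
    payload = await request.json()
    request_id = "demo-request"
    results = engine.generate(payload["prompt"], request_id)

    async def stream():
        async for text in results:
            yield (text + "\0").encode("utf-8")

    if payload.get("stream", False):
        # Cancellation of this generator on disconnect ends the request.
        return StreamingResponse(stream())

    final = None
    async for text in results:
        if await request.is_disconnected():
            await engine.abort(request_id)
            return {"error": "client disconnected"}
        final = text
    return {"text": final}
```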

View File

@@ -37,7 +37,22 @@ class LLM:
the `torch_dtype` attribute specified in the model config file. the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead. use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling. seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
""" """
def __init__( def __init__(
@@ -48,7 +63,11 @@ class LLM:
trust_remote_code: bool = False, trust_remote_code: bool = False,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: str = "auto", dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
seed: int = 0, seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
**kwargs, **kwargs,
) -> None: ) -> None:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
@@ -60,7 +79,11 @@ class LLM:
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
quantization=quantization,
revision=revision,
seed=seed, seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
**kwargs, **kwargs,
) )
self.llm_engine = LLMEngine.from_engine_args(engine_args) self.llm_engine = LLMEngine.from_engine_args(engine_args)

View File

@@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi import fastapi
import uvicorn import uvicorn
from fastapi import BackgroundTasks, Request from fastapi import Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
@@ -130,6 +130,8 @@ async def check_length(
input_ids = tokenizer(prompt).input_ids input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids) token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = max_model_len - token_num
if token_num + request.max_tokens > max_model_len: if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response( return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
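When an OpenAI-style request omits `max_tokens`, the server now fills in the remaining context budget. The arithmetic, as a tiny sketch:

```python
# Sketch of the max_tokens defaulting added above (illustrative helper).
def resolve_max_tokens(requested, prompt_token_num, max_model_len):
    # None means "use whatever context budget remains after the prompt".
    if requested is None:
        requested = max_model_len - prompt_token_num
    if prompt_token_num + requested > max_model_len:
        raise ValueError(
            f"Requested {prompt_token_num + requested} tokens but the model "
            f"supports at most {max_model_len}.")
    return requested

assert resolve_max_tokens(None, 100, 4096) == 3996
assert resolve_max_tokens(16, 100, 4096) == 16
```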
@@ -196,7 +198,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
if request.logit_bias is not None: if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine. # TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported") "logit_bias is not currently supported")
@@ -217,11 +219,13 @@ async def create_chat_completion(request: ChatCompletionRequest,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
stop=request.stop, stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
best_of=request.best_of, best_of=request.best_of,
top_k=request.top_k, top_k=request.top_k,
ignore_eos=request.ignore_eos, ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search, use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
) )
except ValueError as e: except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -229,9 +233,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
result_generator = engine.generate(prompt, sampling_params, request_id, result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids) token_ids)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json( def create_stream_response_json(
index: int, index: int,
text: str, text: str,
@@ -291,19 +292,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
# Streaming response # Streaming response
if request.stream: if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(), return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream", media_type="text/event-stream")
background=background_tasks)
# Non-streaming response # Non-streaming response
final_res: RequestOutput = None final_res: RequestOutput = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await abort_request() await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res
@@ -379,7 +376,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported") "suffix is not currently supported")
if request.logit_bias is not None: if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine. # TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported") "logit_bias is not currently supported")
@@ -425,10 +422,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
top_p=request.top_p, top_p=request.top_p,
top_k=request.top_k, top_k=request.top_k,
stop=request.stop, stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos, ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
logprobs=request.logprobs, logprobs=request.logprobs,
use_beam_search=request.use_beam_search, use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
) )
except ValueError as e: except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -448,9 +447,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
and (request.best_of is None or request.n == request.best_of) and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search) and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json( def create_stream_response_json(
index: int, index: int,
text: str, text: str,
@@ -510,19 +506,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
# Streaming response # Streaming response
if stream: if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(), return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream", media_type="text/event-stream")
background=background_tasks)
# Non-streaming response # Non-streaming response
final_res: RequestOutput = None final_res: RequestOutput = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await abort_request() await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res
@@ -623,7 +615,7 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config()) engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len() max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings. # A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer, tokenizer = get_tokenizer(engine_args.tokenizer,

View File

@@ -58,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0 top_p: Optional[float] = 1.0
n: Optional[int] = 1 n: Optional[int] = 1
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
@@ -70,6 +70,8 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[int] = -1 top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
@@ -94,6 +96,8 @@ class CompletionRequest(BaseModel):
top_k: Optional[int] = -1 top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
class LogProbs(BaseModel): class LogProbs(BaseModel):

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
from xformers.ops import AttentionBias from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ class InputMetadata:
context_lens: torch.Tensor, context_lens: torch.Tensor,
max_context_len: int, max_context_len: int,
block_tables: torch.Tensor, block_tables: torch.Tensor,
sliding_window: Optional[int] = None,
) -> None: ) -> None:
self.seq_groups = seq_groups self.seq_groups = seq_groups
self.seq_data = seq_data self.seq_data = seq_data
@@ -38,6 +39,24 @@ class InputMetadata:
self.max_context_len = max_context_len self.max_context_len = max_context_len
self.block_tables = block_tables self.block_tables = block_tables
self.to_cache = None
if sliding_window is not None:
# We need to keep the positions of sliding windows within
# the key / value tables, this is helpful to know which
# elements we need to cache and where
to_cache, start_idx = [], 0
for prompt_len in self.prompt_lens:
to_cache.extend(
range(
start_idx + max(0, prompt_len - sliding_window),
start_idx + prompt_len,
))
start_idx += prompt_len
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
self.to_cache = torch.tensor(to_cache,
dtype=torch.int32,
device=self.slot_mapping.device)
self.num_prompts = len(prompt_lens) self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens) self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0] self.num_generation_tokens = context_lens.shape[0]
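With a sliding window, only the last `sliding_window` positions of each prompt need to be written to the KV cache; `to_cache` records exactly those slot indices, while generation tokens are always cached. A plain-Python sketch of the index computation:

```python
# Sketch of the to_cache index computation for sliding-window prompts.
def sliding_window_cache_indices(prompt_lens, total_slots, sliding_window):
    to_cache, start_idx = [], 0
    for prompt_len in prompt_lens:
        # Keep only the last `sliding_window` positions of each prompt.
        to_cache.extend(
            range(start_idx + max(0, prompt_len - sliding_window),
                  start_idx + prompt_len))
        start_idx += prompt_len
    # Everything after the prompts (generation tokens) is always cached.
    to_cache.extend(range(start_idx, total_slots))
    return to_cache

# Two prompts of length 6 and 3, window of 4, plus 2 generation slots.
print(sliding_window_cache_indices([6, 3], 11, 4))
# -> [2, 3, 4, 5, 6, 7, 8, 9, 10]
```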

View File

@@ -1,5 +1,5 @@
"""Multi-head attention.""" """Multi-head attention."""
from typing import List, Optional from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -9,8 +9,10 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm import attention_ops from vllm import attention_ops
from vllm import cache_ops from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
RotaryEmbedding)
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
@@ -56,12 +58,14 @@ class PagedAttention(nn.Module):
num_heads: int, num_heads: int,
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: Optional[int] = None) -> None: num_kv_heads: Optional[int] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -73,12 +77,19 @@ class PagedAttention(nn.Module):
raise ValueError(f"head_size ({self.head_size}) is not supported. " raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def set_attn_bias(self, input_metadata: InputMetadata) -> None: def set_attn_bias(
self,
input_metadata: InputMetadata,
dtype: torch.dtype,
) -> None:
del dtype # Unused.
if input_metadata.attn_bias: if input_metadata.attn_bias:
# Already set by a previous layer. # Already set by a previous layer.
return return
prompt_lens = input_metadata.prompt_lens prompt_lens = input_metadata.prompt_lens
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(self.sliding_window)
input_metadata.attn_bias.append(attn_bias) input_metadata.attn_bias.append(attn_bias)
def multi_query_kv_attention( def multi_query_kv_attention(
@@ -196,7 +207,7 @@ class PagedAttention(nn.Module):
if num_prompt_tokens > 0: if num_prompt_tokens > 0:
# Prompt run. # Prompt run.
assert input_metadata.num_generation_tokens == 0 assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata) self.set_attn_bias(input_metadata, dtype=query.dtype)
self.multi_query_kv_attention( self.multi_query_kv_attention(
output[:num_prompt_tokens], output[:num_prompt_tokens],
query[:num_prompt_tokens], query[:num_prompt_tokens],
@@ -216,12 +227,20 @@ class PagedAttention(nn.Module):
if (num_valid_tokens > 0 and key_cache is not None if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None): and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv. # The stride is 3 because the key and value are sliced from qkv.
key_to_cache = key[:num_valid_tokens]
value_to_cache = value[:num_valid_tokens]
slot_mapping = input_metadata.slot_mapping
if input_metadata.to_cache is not None:
key_to_cache = key_to_cache[input_metadata.to_cache]
value_to_cache = value_to_cache[input_metadata.to_cache]
slot_mapping = slot_mapping[input_metadata.to_cache]
cache_ops.reshape_and_cache( cache_ops.reshape_and_cache(
key[:num_valid_tokens], key_to_cache,
value[:num_valid_tokens], value_to_cache,
key_cache, key_cache,
value_cache, value_cache,
input_metadata.slot_mapping, slot_mapping,
) )
if input_metadata.num_generation_tokens > 0: if input_metadata.num_generation_tokens > 0:
@@ -242,7 +261,7 @@ class PagedAttention(nn.Module):
class PagedAttentionWithRoPE(PagedAttention): class PagedAttentionWithRoPE(PagedAttention):
"""PagedAttention with rotary embedding.""" """PagedAttention with rotary positional embedding."""
def __init__( def __init__(
self, self,
@@ -254,26 +273,31 @@ class PagedAttentionWithRoPE(PagedAttention):
         base: int = 10000,
         num_kv_heads: Optional[int] = None,
         is_neox_style: bool = True,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        sliding_window: Optional[int] = None,
     ) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads)
-        self.is_neox_style = is_neox_style
-        # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, device="cuda").float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
-        cos = freqs.cos()
-        sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
-        # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model.
-        # TODO(woosuk): Make it more robust.
-        torch_dtype = torch.get_default_dtype()
-        cache = cache.to(torch_dtype)
-        # Embedding size: [max_position, rotary_dim]
-        self.register_buffer("cos_sin_cache", cache, persistent=False)
+        super().__init__(num_heads,
+                         head_size,
+                         scale,
+                         num_kv_heads,
+                         sliding_window=sliding_window)
+        if rope_scaling is None:
+            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
+                                              max_position, base,
+                                              is_neox_style)
+        else:
+            scaling_type = rope_scaling["type"]
+            scaling_factor = rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LinearScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            elif scaling_type == "dynamic":
+                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward( def forward(
self, self,
@@ -290,7 +314,7 @@ class PagedAttentionWithRoPE(PagedAttention):
Args: Args:
positions: shape = [num_tokens] positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x, key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
@@ -306,14 +330,7 @@ class PagedAttentionWithRoPE(PagedAttention):
# Apply rotary embedding to the query and key before passing them # Apply rotary embedding to the query and key before passing them
# to the attention op. # to the attention op.
pos_encoding_ops.rotary_embedding( query, key = self.rotary_emb(positions, query, key)
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return super().forward( return super().forward(
query, query,
key, key,
@@ -340,13 +357,14 @@ class PagedAttentionWithALiBi(PagedAttention):
slopes = torch.tensor(slopes, dtype=torch.float32) slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False) self.register_buffer("alibi_slopes", slopes, persistent=False)
def set_attn_bias(self, input_metadata: InputMetadata) -> None: def set_attn_bias(self, input_metadata: InputMetadata,
dtype: torch.dtype) -> None:
if input_metadata.attn_bias: if input_metadata.attn_bias:
# Already set by a previous layer. # Already set by a previous layer.
return return
# Generates ALiBi mask for each prompt. # Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens: for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len) bias = torch.arange(prompt_len, dtype=dtype)
# Note(zhuohan): HF uses # Note(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)` # `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but # here. We find that both biases give the same results, but
@@ -364,6 +382,7 @@ class PagedAttentionWithALiBi(PagedAttention):
prompt_len, prompt_len,
padded_len, padded_len,
device=self.alibi_slopes.device, device=self.alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias) )[:, :, :, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None]) bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias) attn_bias = LowerTriangularMaskWithTensorBias(bias)

View File

@@ -0,0 +1,37 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
}
class ParallelLinear:
@classmethod
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return ColumnParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
return quant_linear_cls(*args, **kwargs)
@classmethod
def row(cls, *args, **kwargs) -> RowParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return RowParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
return quant_linear_cls(*args, **kwargs)
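The registry maps a quantization method name to its (column, row) parallel-linear implementations, falling back to the unquantized layers when no `quant_config` is supplied. A self-contained stand-in sketch of the lookup (placeholder classes; the real code calls `quant_config.get_name()`):

```python
class ColumnLinear: ...
class RowLinear: ...
class AWQColumnLinear(ColumnLinear): ...
class AWQRowLinear(RowLinear): ...

_REGISTRY = {"awq": (AWQColumnLinear, AWQRowLinear)}

def make_column_linear(quant_config=None, **kwargs):
    # No quant_config: plain tensor-parallel column linear.
    if quant_config is None:
        return ColumnLinear(**kwargs)
    name = quant_config["name"]
    if name not in _REGISTRY:
        raise ValueError(f"No quantized linear is found for {name}")
    return _REGISTRY[name][0](**kwargs)

print(type(make_column_linear({"name": "awq"})).__name__)  # -> AWQColumnLinear
```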

View File

@@ -0,0 +1,102 @@
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
ColumnParallelLinear, RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
self.qweight = Parameter(
torch.empty(
self.input_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
return out.reshape(out_shape)
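AWQ packs eight 4-bit weights into each int32, so `qweight`'s output dimension is divided by `pack_factor` while `scales` and `qzeros` carry one row per `group_size` input channels. Worked shape arithmetic, assuming the common 4-bit, group-size-128 configuration:

```python
# Worked example of the AWQ tensor shapes created above (illustrative sizes).
weight_bits = 4
group_size = 128
pack_factor = 32 // weight_bits          # 8 int4 values packed per int32

input_size, output_size = 4096, 11008    # e.g. a LLaMA-7B MLP projection
qweight_shape = (input_size, output_size // pack_factor)               # (4096, 1376) int32
qzeros_shape = (input_size // group_size, output_size // pack_factor)  # (32, 1376) int32
scales_shape = (input_size // group_size, output_size)                 # (32, 11008) fp16
print(qweight_shape, qzeros_shape, scales_shape)
```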

View File

@@ -0,0 +1,169 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
from typing import Tuple, Union
import torch
import torch.nn as nn
from vllm import pos_encoding_ops
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
cache = self._compute_cos_sin_cache()
cache = cache.to(torch.get_default_dtype())
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings,
dtype=torch.float,
device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# pos_encoding_ops.rotary_embedding() is an in-place operation that
# updates the query and key tensors.
pos_encoding_ops.rotary_embedding(positions, query, key,
self.head_size, self.cos_sin_cache,
self.is_neox_style)
return query, key
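The cos/sin cache stores, for every position, the cosines and sines of `position * inv_freq` for each rotary pair. A small torch sketch of the same computation on CPU:

```python
import torch

# Sketch of the cos/sin cache layout: one row per position, the cosines for
# each rotary pair followed by the sines.
def cos_sin_cache(rotary_dim, max_position, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float)
                               / rotary_dim))
    t = torch.arange(max_position, dtype=torch.float)
    freqs = torch.einsum("i,j->ij", t, inv_freq)   # [max_position, rotary_dim/2]
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

cache = cos_sin_cache(rotary_dim=8, max_position=16)
print(cache.shape)   # torch.Size([16, 8])
print(cache[0])      # position 0: cos = 1, sin = 0 for every pair
```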
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
Credits to the Reddit user /u/kaiokendev
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = t / self.scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
base = self.base * (
(self.scaling_factor * max_len / self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float, device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
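Dynamic NTK scaling keeps positions unscaled but enlarges the RoPE base according to the expression above. A worked computation for a LLaMA-style configuration (illustrative numbers):

```python
# Worked example of the dynamic NTK base adjustment above.
rotary_dim = 128
base = 10000.0
max_position_embeddings = 4096
scaling_factor = 2.0

max_len = max_position_embeddings * scaling_factor          # 8192
adjusted_base = base * (
    (scaling_factor * max_len / max_position_embeddings)
    - (scaling_factor - 1)
) ** (rotary_dim / (rotary_dim - 2))
print(max_len, round(adjusted_base))   # ~30.5k, roughly 3x the original base
```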

View File

@@ -1,15 +1,14 @@
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Optional, Tuple
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region) gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceOutputs from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
@@ -44,12 +43,8 @@ class Sampler(nn.Module):
hidden_states = _prune_hidden_states(hidden_states, input_metadata) hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t()) logits = _get_logits(hidden_states, embedding, embedding_bias,
if embedding_bias is not None: self.vocab_size)
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]
# Apply presence and frequency penalties. # Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata) output_tokens = _get_output_tokens(input_metadata)
@@ -59,7 +54,7 @@ class Sampler(nn.Module):
assert len(presence_penalties) == logits.shape[0] assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties, logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size) frequency_penalties)
# Apply temperature scaling. # Apply temperature scaling.
temperatures = _get_temperatures(input_metadata) temperatures = _get_temperatures(input_metadata)
@@ -82,26 +77,55 @@ class Sampler(nn.Module):
# We use float32 for probabilities and log probabilities. # We use float32 for probabilities and log probabilities.
# Compute the probabilities. # Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float) probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p and top-k). # Compute the log probabilities.
logprobs = torch.log(probs) # Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
return _sample(probs, logprobs, input_metadata) return _sample(probs, logprobs, input_metadata)
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor],
vocab_size: int) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :vocab_size]
return logits
def _prune_hidden_states( def _prune_hidden_states(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
last_token_indices = {t: [] for t in SamplingType}
start_idx = 0 start_idx = 0
last_token_indicies: List[int] = [] for i, seq_group in enumerate(input_metadata.seq_groups):
for prompt_len in input_metadata.prompt_lens: seq_ids, sampling_params = seq_group
last_token_indicies.append(start_idx + prompt_len - 1) sampling_type = sampling_params.sampling_type
start_idx += prompt_len if i < input_metadata.num_prompts:
last_token_indicies.extend( assert len(seq_ids) == 1, "Prompt input should have only one seq."
range(start_idx, start_idx + input_metadata.num_generation_tokens)) prompt_len = input_metadata.prompt_lens[i]
return hidden_states.index_select( last_token_indices[sampling_type].append(start_idx + prompt_len -
0, torch.tensor(last_token_indicies, device=hidden_states.device)) 1)
start_idx += prompt_len
else:
num_seqs = len(seq_ids)
last_token_indices[sampling_type].extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
all_last_token_indices = []
for sampling_type in SamplingType:
all_last_token_indices.extend(last_token_indices[sampling_type])
all_last_token_indices = torch.tensor(all_last_token_indices,
dtype=torch.long,
device=hidden_states.device)
return hidden_states.index_select(0, all_last_token_indices)
def _get_penalties( def _get_penalties(
@@ -109,37 +133,22 @@ def _get_penalties(
# Collect the presence and frequency penalties. # Collect the presence and frequency penalties.
presence_penalties: List[float] = [] presence_penalties: List[float] = []
frequency_penalties: List[float] = [] frequency_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for seq_group in input_metadata.seq_groups:
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty f = sampling_params.frequency_penalty
if i < input_metadata.num_prompts: presence_penalties += [p] * len(seq_ids)
# A prompt input. frequency_penalties += [f] * len(seq_ids)
presence_penalties.append(p)
frequency_penalties.append(f)
else:
# A generation token.
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
return presence_penalties, frequency_penalties return presence_penalties, frequency_penalties
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = [] output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for seq_group in input_metadata.seq_groups:
seq_ids, _ = seq_group seq_ids, _ = seq_group
if i < input_metadata.num_prompts: for seq_id in seq_ids:
# A prompt input.
# NOTE: While the prompt input usually has no output tokens,
# it may have output tokens in the case of recomputation.
seq_id = seq_ids[0]
seq_data = input_metadata.seq_data[seq_id] seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids) output_tokens.append(seq_data.output_token_ids)
else:
# A generation token.
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids)
return output_tokens return output_tokens
@@ -148,11 +157,8 @@ def _apply_penalties(
output_tokens: List[List[int]], output_tokens: List[List[int]],
presence_penalties: List[float], presence_penalties: List[float],
frequency_penalties: List[float], frequency_penalties: List[float],
vocab_size: int,
) -> torch.Tensor: ) -> torch.Tensor:
num_seqs = logits.shape[0] num_seqs, vocab_size = logits.shape
# Collect the indices of sequences that have non-zero penalties.
indices = []
for i in range(num_seqs): for i in range(num_seqs):
if not output_tokens[i]: if not output_tokens[i]:
continue continue
@@ -160,40 +166,47 @@ def _apply_penalties(
f = frequency_penalties[i] f = frequency_penalties[i]
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS: if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
continue continue
indices.append(i) break
else:
# Return early if all sequences have zero penalties. # Return early if all sequences have zero penalties.
if not indices:
return logits return logits
bin_counts = [] max_output_len = max(len(tokens) for tokens in output_tokens)
for i in indices: padded_output_tokens = [
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size)) tokens + [vocab_size] * (max_output_len - len(tokens))
bin_counts = np.stack(bin_counts, axis=0) for tokens in output_tokens
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype, ]
device=logits.device) output_tokens_tensor = torch.tensor(padded_output_tokens,
dtype=torch.long,
device=logits.device)
# Compute the bin counts for the output tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=logits.device)
bin_counts.scatter_add_(1, output_tokens_tensor,
torch.ones_like(output_tokens_tensor))
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(frequency_penalties, frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype, dtype=logits.dtype,
device=logits.device) device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(presence_penalties, presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype, dtype=logits.dtype,
device=logits.device) device=logits.device)
# We follow the definition in OpenAI API. # We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details # Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype) logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
return logits return logits
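The penalties are now applied to every sequence at once: the output tokens are padded into a `[num_seqs, max_len]` tensor, `scatter_add_` builds per-sequence token counts, and the frequency/presence terms are subtracted in bulk. A minimal torch sketch of that vectorised path:

```python
import torch

# Minimal sketch of the vectorized penalty application above.
def apply_penalties(logits, output_tokens, presence, frequency):
    num_seqs, vocab_size = logits.shape
    max_len = max(len(t) for t in output_tokens)
    # Pad with `vocab_size`, which lands in an extra bin that is dropped later.
    padded = [t + [vocab_size] * (max_len - len(t)) for t in output_tokens]
    padded = torch.tensor(padded, dtype=torch.long)
    counts = torch.zeros(num_seqs, vocab_size + 1, dtype=torch.long)
    counts.scatter_add_(1, padded, torch.ones_like(padded))
    counts = counts[:, :vocab_size]                 # drop the padding bin
    freq = torch.tensor(frequency, dtype=logits.dtype)
    pres = torch.tensor(presence, dtype=logits.dtype)
    logits = logits - freq.unsqueeze(1) * counts
    logits = logits - pres.unsqueeze(1) * (counts > 0)
    return logits

logits = torch.zeros(2, 5)
out = apply_penalties(logits, [[1, 1, 3], [2]],
                      presence=[0.5, 0.0], frequency=[1.0, 0.0])
print(out)
# seq 0: token 1 gets -(2*1.0 + 0.5), token 3 gets -(1.0 + 0.5); seq 1 unchanged.
```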
def _get_temperatures(input_metadata: InputMetadata) -> List[float]: def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits. # Collect the temperatures for the logits.
temperatures: List[float] = [] temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for seq_group in input_metadata.seq_groups:
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature temperature = sampling_params.temperature
if temperature < _SAMPLING_EPS: if temperature < _SAMPLING_EPS:
@@ -201,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# (i.e., greedy sampling or beam search). # (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero. # Set the temperature to 1 to avoid division by zero.
temperature = 1.0 temperature = 1.0
temperatures += [temperature] * len(seq_ids)
if i < input_metadata.num_prompts:
# A prompt input.
temperatures.append(temperature)
else:
# A generation token.
temperatures += [temperature] * len(seq_ids)
return temperatures return temperatures
@@ -217,21 +224,15 @@ def _get_top_p_top_k(
) -> Tuple[List[float], List[int]]: ) -> Tuple[List[float], List[int]]:
top_ps: List[float] = [] top_ps: List[float] = []
top_ks: List[int] = [] top_ks: List[int] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for seq_group in input_metadata.seq_groups:
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p top_p = sampling_params.top_p
# k should not be greater than the vocab size. # k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size) top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation. # k=-1 means no truncation.
top_k = vocab_size if top_k == -1 else top_k top_k = vocab_size if top_k == -1 else top_k
if i < input_metadata.num_prompts: top_ps += [top_p] * len(seq_ids)
# A prompt input. top_ks += [top_k] * len(seq_ids)
top_ps.append(top_p)
top_ks.append(top_k)
else:
# A generation token.
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
return top_ps, top_ks return top_ps, top_ks
@@ -267,95 +268,154 @@ def _apply_top_p_top_k(
 def _get_topk_logprobs(
     logprobs: torch.Tensor,
     num_logprobs: Optional[int],
-) -> Dict[int, float]:
+) -> List[Dict[int, float]]:
+    num_seqs = logprobs.size(0)
     if num_logprobs is None or num_logprobs == 0:
-        return {}
-    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
-    if num_logprobs == 1:
-        topk_logprobs = [topk_logprobs.item()]
-        topk_ids = [topk_ids.item()]
-    else:
-        topk_logprobs = topk_logprobs.tolist()
-        topk_ids = topk_ids.tolist()
-    token_to_logprob: Dict[int, float] = {}
-    for token_id, logprob in zip(topk_ids, topk_logprobs):
-        token_to_logprob[token_id] = logprob
-    return token_to_logprob
+        return [{} for _ in range(num_seqs)]
+    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
+                                                 num_logprobs,
+                                                 dim=-1)
+    all_topk_logprobs = all_topk_logprobs.cpu()
+    all_topk_ids = all_topk_ids.cpu()
+    all_token_to_logprob = []
+    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
+        token_to_logprob: Dict[int, float] = {}
+        for token_id, logprob in zip(topk_ids, topk_logprobs):
+            token_to_logprob[token_id.item()] = logprob.item()
+        all_token_to_logprob.append(token_to_logprob)
+    return all_token_to_logprob
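The rewritten helper computes top-k over the whole batch in a single torch.topk call and only then materializes one Python dict per row. A small sketch of the same pattern with toy shapes (not the vLLM function itself):

import torch

logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)  # 2 sequences, 5-token vocab
topk_vals, topk_ids = torch.topk(logprobs, k=2, dim=-1)

per_row = []
for vals, ids in zip(topk_vals.cpu(), topk_ids.cpu()):
    per_row.append({i.item(): v.item() for i, v in zip(ids, vals)})
print(per_row)  # e.g. [{3: -0.9, 0: -1.2}, {1: -0.7, 4: -1.5}]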
-def _sample_from_prompt(
-    prob: torch.Tensor,
-    sampling_params: SamplingParams,
-) -> List[int]:
-    if sampling_params.use_beam_search:
-        # Beam search.
-        beam_width = sampling_params.best_of
-        # Sample 2 * beam_width candidates to make sure that with high
-        # probability we can get `beam_width` candidates in addition to
-        # the finished sequences for the next iteration. See
-        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-        # for details. See also HF reference:
-        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-        _, next_token_ids = torch.topk(prob, 2 * beam_width)
-        next_token_ids = next_token_ids.tolist()
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert sampling_params.best_of == 1
-        next_token_id = torch.argmax(prob)
-        next_token_ids = [next_token_id.item()]
-    else:
-        # Random sampling.
-        # Sample `best_of` tokens for the prompt.
-        num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob,
-                                           num_samples=num_seqs,
-                                           replacement=True)
-        next_token_ids = next_token_ids.tolist()
-    return next_token_ids
+def _build_sequence_outputs(
+    parent_ids: List[int],
+    next_token_ids: List[int],
+    selected_token_logprobs: torch.Tensor,
+    parent_seq_ids: List[int],
+    parent_logprobs: torch.Tensor,
+    num_output_logprobs: Optional[int],
+) -> List[SequenceOutputs]:
+    # Get top-k log probabilities for the next tokens.
+    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
+    seq_outputs: List[SequenceOutputs] = []
+    for parent_id, next_token_id, token_logprob in zip(
+            parent_ids, next_token_ids, selected_token_logprobs):
+        output_logprobs = next_logprobs[parent_id].copy()
+        output_logprobs[next_token_id] = token_logprob
+        seq_outputs.append(
+            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
+                            output_logprobs))
+    return seq_outputs
-def _sample_from_generation_tokens(
-    seq_ids: List[int],
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    seq_logprobs: List[float],
-    sampling_params: SamplingParams,
-) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.best_of can be greater than
-    # len(seq_ids) because some sequences in the group might have
-    # been already terminated.
-    if sampling_params.use_beam_search:
-        # Beam search.
-        # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(seq_logprobs,
-                                    dtype=torch.float,
-                                    device=logprobs.device)
-        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
-        vocab_size = logprobs.size(-1)
-        beam_width = len(seq_ids)
-        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
-        topk_ids = topk_ids.tolist()
-        seq_idx = [i // vocab_size for i in topk_ids]
-        parent_seq_ids = [seq_ids[i] for i in seq_idx]
-        next_token_ids = [i % vocab_size for i in topk_ids]
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert len(seq_ids) == 1
-        next_token_id = torch.argmax(probs, dim=-1)
-        next_token_ids = [int(next_token_id.item())]
-        parent_seq_ids = seq_ids
-    else:
-        # Random sampling.
-        # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples=1,
-                                           replacement=True)
-        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
-        parent_seq_ids = seq_ids
-    return parent_seq_ids, next_token_ids
+def _greedy_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    logprobs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    samples = torch.argmax(logprobs, dim=-1).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group in selected_seq_groups:
+        seq_ids, _ = seq_group
+        num_parent_seqs = len(seq_ids)
+        assert num_parent_seqs == 1, (
+            "Greedy sampling should have only one seq.")
+        parent_ids = list(range(num_parent_seqs))
+        next_token_ids = [samples[sample_idx].item()]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results
+def _random_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    probs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # Find the maximum best_of value of the prompt phase requests.
+    max_best_of = 1
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        if is_prompt:
+            seq_ids, sampling_params = seq_group
+            max_best_of = max(max_best_of, sampling_params.best_of)
+    random_samples = torch.multinomial(probs,
+                                       num_samples=max_best_of,
+                                       replacement=True).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * sampling_params.best_of
+            next_token_ids = random_samples[
+                sample_idx, :sampling_params.best_of].tolist()
+        else:
+            # Generation phase.
+            parent_ids = list(range(num_parent_seqs))
+            next_token_ids = random_samples[sample_idx:sample_idx +
+                                            num_parent_seqs, 0].tolist()
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == probs.size(0)
+    return results
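The new _random_sample draws max_best_of candidates per row with a single torch.multinomial call and then slices: prompt rows keep their first best_of columns, generation rows keep only column 0. A toy sketch of that slicing (shapes are made up for illustration):

import torch

probs = torch.softmax(torch.randn(3, 10), dim=-1)  # row 0: prompt, rows 1-2: generation seqs
best_of = 3                                         # the prompt request wants 3 candidates

# One batched call instead of one multinomial per sequence group.
samples = torch.multinomial(probs, num_samples=best_of, replacement=True)

prompt_tokens = samples[0, :best_of].tolist()       # 3 candidate tokens for the prompt row
gen_tokens = samples[1:3, 0].tolist()               # 1 token for each generation row
print(prompt_tokens, gen_tokens)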
def _beam_search_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
seq_data: Dict[int, SequenceData],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
#
# Note: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
if is_prompt:
# Prompt phase.
assert num_parent_seqs == 1, (
"Prompt input should have only one seq.")
parent_ids = [0] * (2 * beam_width)
_, next_token_ids = torch.topk(seq_group_logprobs[0],
2 * beam_width)
next_token_ids = next_token_ids.tolist()
else:
# Generation phase.
cumulative_logprobs = [
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
]
cumulative_logprobs = torch.tensor(
cumulative_logprobs,
dtype=torch.float,
device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width)
topk_ids = topk_ids.tolist()
vocab_size = seq_group_logprobs.size(-1)
parent_ids = [i // vocab_size for i in topk_ids]
next_token_ids = [i % vocab_size for i in topk_ids]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
return results
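The beam-search branch above keeps the classic trick of over-sampling 2 * beam_width candidates from the flattened (beam, vocab) log-prob matrix and recovering the parent beam and token id with integer division and modulo. A compact sketch of just that indexing (toy sizes, not the vLLM function):

import torch

vocab_size = 6
beam_width = 2
# Per-beam next-token log-probs; in the real code these are first shifted by
# each beam's cumulative log-prob before the topk.
beam_logprobs = torch.log_softmax(torch.randn(beam_width, vocab_size), dim=-1)

_, topk_ids = torch.topk(beam_logprobs.flatten(), 2 * beam_width)
parent_ids = [i // vocab_size for i in topk_ids.tolist()]      # which beam each came from
next_token_ids = [i % vocab_size for i in topk_ids.tolist()]   # which token to append
print(parent_ids, next_token_ids)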
 def _sample(
@@ -363,65 +423,80 @@ def _sample(
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
-    seq_outputs: SamplerOutput = []
-    # TODO(woosuk): Optimize.
-    idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_group_outputs: List[SequenceOutputs] = []
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            parent_seq_id = seq_ids[0]
-            prob = probs[idx]
-            logprob = logprobs[idx]
-            idx += 1
-            # Sample the next tokens.
-            next_token_ids = _sample_from_prompt(prob, sampling_params)
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(logprob,
-                                               sampling_params.logprobs)
-            # Build the output.
-            for next_token_id in next_token_ids:
-                output_logprobs = next_logprobs.copy()
-                output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
-                                    output_logprobs))
-        else:
-            # Generate the next tokens for generation tokens.
-            num_parent_seqs = len(seq_ids)
-            prob = probs[idx:idx + num_parent_seqs]
-            logprob = logprobs[idx:idx + num_parent_seqs]
-            idx += num_parent_seqs
-            # Sample the next tokens.
-            seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
-                seq_ids, prob, logprob, seq_logprobs, sampling_params)
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs: Dict[int, Dict[int, float]] = {}
-            for j, seq_id in enumerate(seq_ids):
-                next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[j], sampling_params.logprobs)
-            # Build the output.
-            for parent_seq_id, next_token_id in zip(parent_seq_ids,
-                                                    next_token_ids):
-                j = seq_ids.index(parent_seq_id)
-                output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[j,
-                                                         next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
                                     output_logprobs))
-        seq_outputs.append(seq_group_outputs)
-    return seq_outputs
+    categorized_seq_group_ids = {t: [] for t in SamplingType}
+    category_num_tokens = {t: 0 for t in SamplingType}
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        sampling_type = sampling_params.sampling_type
+        categorized_seq_group_ids[sampling_type].append(i)
+        num_seqs = len(seq_ids)
+        category_num_tokens[sampling_type] += num_seqs
+
+    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
+    category_start_idx = 0
+    for sampling_type in SamplingType:
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        num_tokens = category_num_tokens[sampling_type]
+        if num_tokens == 0:
+            continue
+        category_logprobs = logprobs[category_start_idx:category_start_idx +
+                                     num_tokens]
+        category_probs = probs[category_start_idx:category_start_idx +
+                               num_tokens]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, category_logprobs)
+        elif sampling_type == SamplingType.RANDOM:
+            sample_results = _random_sample(seq_groups, is_prompts,
+                                            category_probs)
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 input_metadata.seq_data,
+                                                 category_logprobs)
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+        # Batched query for logprobs of selected token
+        batched_logprobs_query_seq_indices: List[int] = []
+        batched_logprobs_query_token_indices: List[int] = []
+        sample_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_parent_seqs = len(seq_ids)
+            batched_logprobs_query_seq_indices.extend(
+                [sample_idx + parent_id for parent_id in parent_ids])
+            batched_logprobs_query_token_indices.extend(next_token_ids)
+            sample_idx += num_parent_seqs
+        assert sample_idx == num_tokens
+        batched_logprobs_query_result = category_logprobs[[
+            batched_logprobs_query_seq_indices,
+            batched_logprobs_query_token_indices
+        ]].tolist()
+
+        # Build the sequence outputs.
+        sample_idx = 0
+        result_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_results = len(next_token_ids)
+            num_parent_seqs = len(seq_ids)
+            parent_logprobs = category_logprobs[sample_idx:sample_idx +
+                                                num_parent_seqs]
+            selected_token_logprobs = batched_logprobs_query_result[
+                result_idx:result_idx + num_results]
+            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
+                                                 selected_token_logprobs,
+                                                 seq_ids, parent_logprobs,
+                                                 sampling_params.logprobs)
+            seq_outputs_dict[seq_group_id] = seq_output
+            sample_idx += num_parent_seqs
+            result_idx += num_results
+        assert sample_idx == num_tokens
+        category_start_idx += num_tokens
+
+    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
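The net effect of the sampler rewrite is that sequence groups are first bucketed by sampling type (greedy, random, beam) and each bucket is then handled with one batched tensor call over its contiguous rows of logits, instead of a Python loop over every group. A stripped-down sketch of the bucketing idea, with stand-ins for vLLM's SamplingType and seq_groups (not the real classes):

import enum
from collections import defaultdict

class SamplingType(enum.Enum):  # stand-in for vllm's enum
    GREEDY = 0
    RANDOM = 1
    BEAM = 2

# (group_id, sampling_type, num_seqs) triples standing in for seq_groups.
seq_groups = [(0, SamplingType.GREEDY, 1), (1, SamplingType.RANDOM, 4),
              (2, SamplingType.GREEDY, 1)]

buckets = defaultdict(list)
for group_id, sampling_type, num_seqs in seq_groups:
    buckets[sampling_type].append((group_id, num_seqs))

# Each bucket can now be processed with a single vectorized call
# (argmax / multinomial / topk) over its slice of the logits.
for sampling_type, groups in buckets.items():
    print(sampling_type, groups)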

View File

@@ -8,7 +8,8 @@ from transformers import PretrainedConfig
 from vllm.config import ModelConfig
 from vllm.model_executor.models import *  # pylint: disable=wildcard-import
-from vllm.model_executor.weight_utils import initialize_dummy_weights
+from vllm.model_executor.weight_utils import (get_quant_config,
+                                              initialize_dummy_weights)
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
@@ -24,12 +25,18 @@ _MODEL_REGISTRY = {
     "InternLMForCausalLM": InternLMForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
+    "MistralForCausalLM": MistralForCausalLM,
     "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
     "QWenLMHeadModel": QWenLMHeadModel,
     "RWForCausalLM": FalconForCausalLM,
 }
+
+# FIXME(woosuk): Remove this once all models support quantization.
+_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
+    LlamaForCausalLM,
+]
 @contextlib.contextmanager
 def _set_default_torch_dtype(dtype: torch.dtype):
@@ -52,10 +59,38 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
+
+    # Get the quantization config.
+    quant_config = None
+    if model_config.quantization is not None:
+        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
+            raise ValueError(
+                f"Quantization is not supported for {model_class}.")
+        quant_config = get_quant_config(model_config.quantization,
+                                        model_config.model,
+                                        model_config.download_dir)
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < quant_config.get_min_capability():
+            raise ValueError(
+                f"The quantization method {model_config.quantization} is not "
+                "supported for the current GPU. "
+                f"Minimum capability: {quant_config.get_min_capability()}. "
+                f"Current capability: {capability}.")
+        supported_dtypes = quant_config.get_supported_act_dtypes()
+        if model_config.dtype not in supported_dtypes:
+            raise ValueError(
+                f"{model_config.dtype} is not supported for quantization "
+                f"method {model_config.quantization}. Supported dtypes: "
+                f"{supported_dtypes}")
+
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        model = model_class(model_config.hf_config)
+        if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
+            model = model_class(model_config.hf_config, quant_config)
+        else:
+            model = model_class(model_config.hf_config)
         if model_config.load_format == "dummy":
             model = model.cuda()
             # NOTE(woosuk): For accurate performance evaluation, we assign
@@ -64,6 +99,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         else:
             # Load the weights from the cached or downloaded files.
             model.load_weights(model_config.model, model_config.download_dir,
-                               model_config.load_format)
+                               model_config.load_format, model_config.revision)
             model = model.cuda()
     return model.eval()
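get_model now validates the requested quantization method against the GPU before instantiating the model; the check packs the (major, minor) CUDA compute capability into a two-digit integer and compares it to the method's minimum. A small standalone illustration (torch.cuda.get_device_capability is the standard PyTorch API; the threshold value here is a made-up example):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor   # e.g. (8, 0) on an A100 -> 80
    min_capability = 75               # hypothetical minimum (Turing-class)
    if capability < min_capability:
        raise ValueError(
            f"GPU capability {capability} is below the required {min_capability}.")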

View File

@@ -12,6 +12,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.models.qwen import QWenLMHeadModel
+from vllm.model_executor.models.mistral import MistralForCausalLM
 __all__ = [
     "AquilaForCausalLM",
@@ -28,4 +29,5 @@ __all__ = [
     "MPTForCausalLM",
     "OPTForCausalLM",
     "QWenLMHeadModel",
+    "MistralForCausalLM",
 ]

View File

@@ -105,6 +105,8 @@ class AquilaAttention(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -119,6 +121,8 @@ class AquilaAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ColumnParallelLinear(
hidden_size, hidden_size,
@@ -140,6 +144,8 @@ class AquilaAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
base=self.rope_theta,
max_position=self.max_position_embeddings,
) )
def forward( def forward(
@@ -164,10 +170,15 @@ class AquilaDecoderLayer(nn.Module):
def __init__(self, config: AquilaConfig): def __init__(self, config: AquilaConfig):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = AquilaAttention( self.self_attn = AquilaAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_attention_heads, num_kv_heads=config.num_attention_heads,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
) )
self.mlp = AquilaMLP( self.mlp = AquilaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@@ -288,7 +299,8 @@ class AquilaForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size) q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +317,7 @@ class AquilaForCausalLM(nn.Module):
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
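This file, like the other model files below, now reads rope_theta and max_position_embeddings from the HF config only when they are present and otherwise falls back to the historical defaults. The pattern is just getattr with a default; a standalone illustration (not the vLLM module itself):

from transformers import PretrainedConfig

def rope_params(config: PretrainedConfig):
    # Older configs (and older transformers versions) may not define these
    # fields, so fall back to the values the models were trained with.
    rope_theta = getattr(config, "rope_theta", 10000)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    return rope_theta, max_position_embeddings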

View File

@@ -111,6 +111,8 @@ class BaiChuanAttention(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
position_embedding: str, position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -122,6 +124,8 @@ class BaiChuanAttention(nn.Module):
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding self.postion_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name # pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear( self.W_pack = ColumnParallelLinear(
@@ -151,10 +155,13 @@ class BaiChuanAttention(nn.Module):
scaling, alibi_slopes) scaling, alibi_slopes)
else: else:
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(self.num_heads, self.attn = PagedAttentionWithRoPE(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
rotary_dim=self.head_dim) self.scaling,
rotary_dim=self.head_dim,
base=self.rope_theta,
max_position=self.max_position_embeddings)
def forward( def forward(
self, self,
@@ -183,10 +190,15 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str): def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = BaiChuanAttention( self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
position_embedding=position_embedding, position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@@ -303,13 +315,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tp_world_size = get_tensor_model_parallel_world_size() tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue

View File

@@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if name == "lm_head.weight": if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to # Since hidden_states are parallelized, we need to
# load lm_head.weight in parallel. # load lm_head.weight in parallel.

View File

@@ -161,12 +161,17 @@ class FalconAttention(nn.Module):
"Rotary and alibi are mutually exclusive.") "Rotary and alibi are mutually exclusive.")
if self.use_rotary: if self.use_rotary:
# TODO(zhuohan): Pass in correct `max_position`` rope_theta = getattr(config, "rope_theta", 10000)
self.attn = PagedAttentionWithRoPE(self.num_heads, max_position_embeddings = getattr(config,
self.head_dim, "max_position_embeddings", 8192)
self.inv_norm_factor, self.attn = PagedAttentionWithRoPE(
rotary_dim=self.head_dim, self.num_heads,
num_kv_heads=self.num_kv_heads) self.head_dim,
self.inv_norm_factor,
base=rope_theta,
max_position=max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
elif self.use_alibi: elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads head_start = tp_rank * self.num_heads
@@ -420,7 +425,8 @@ class FalconForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tp_size = (get_tensor_model_parallel_world_size()) tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
@@ -452,7 +458,7 @@ class FalconForCausalLM(nn.Module):
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "query_key_value" in name: if "query_key_value" in name:
loaded_weight = convert_pyslice_to_tensor(loaded_weight) loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight_size = loaded_weight.size() loaded_weight_size = loaded_weight.size()

View File

@@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_world_size = ( tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size()) get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
# linear layer. # linear layer.

View File

@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_world_size = ( tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size()) get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
# linear layer. # linear layer.

View File

@@ -67,11 +67,17 @@ class GPTJAttention(nn.Module):
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
assert getattr(config, "rotary", True) assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0 assert config.rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, rope_theta = getattr(config, "rope_theta", 10000)
self.head_size, max_position_embeddings = getattr(config, "max_position_embeddings",
scaling, 8192)
config.rotary_dim, self.attn = PagedAttentionWithRoPE(
is_neox_style=False) self.num_heads,
self.head_size,
scaling,
config.rotary_dim,
base=rope_theta,
max_position=max_position_embeddings,
is_neox_style=False)
self.warmup = False self.warmup = False
def forward( def forward(
@@ -222,11 +228,12 @@ class GPTJForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "attn.bias" in name or "attn.masked_bias" in name: if "attn.bias" in name or "attn.masked_bias" in name:
continue continue

View File

@@ -68,8 +68,16 @@ class GPTNeoXAttention(nn.Module):
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct) rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0 assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size, rope_theta = getattr(config, "rope_theta", 10000)
scaling, rotary_dim) max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_size,
scaling,
rotary_dim,
base=rope_theta,
max_position=max_position_embeddings)
def forward( def forward(
self, self,
@@ -231,11 +239,12 @@ class GPTNeoXForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if ("attention.bias" in name or "attention.masked_bias" in name if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name): or "rotary_emb.inv_freq" in name):
continue continue

View File

@@ -59,6 +59,8 @@ class InternLMAttention(nn.Module):
self, self,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -70,6 +72,8 @@ class InternLMAttention(nn.Module):
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ColumnParallelLinear(
hidden_size, hidden_size,
@@ -85,10 +89,13 @@ class InternLMAttention(nn.Module):
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False, perform_initialization=False,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.attn = PagedAttentionWithRoPE(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
rotary_dim=self.head_dim) self.scaling,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim)
def forward( def forward(
self, self,
@@ -112,9 +119,14 @@ class InternLMDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(self, config: LlamaConfig):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = InternLMAttention( self.self_attn = InternLMAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
) )
self.mlp = InternLMMLP( self.mlp = InternLMMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@@ -233,12 +245,13 @@ class InternLMForCausalLM(nn.Module):
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue

View File

@@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.quantized_linear import ParallelLinear
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
): quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, self.gate_up_proj = ParallelLinear.column(hidden_size,
2 * intermediate_size, 2 * intermediate_size,
bias=False, bias=False,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False,
self.down_proj = RowParallelLinear(intermediate_size, quant_config=quant_config)
hidden_size, self.down_proj = ParallelLinear.row(intermediate_size,
bias=False, hidden_size,
input_is_parallel=True, bias=False,
perform_initialization=False) input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
@@ -87,7 +92,10 @@ class LlamaAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
): rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
@@ -102,28 +110,34 @@ class LlamaAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ParallelLinear.column(
hidden_size, hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) * (self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim, self.head_dim,
bias=False, bias=False,
gather_output=False, gather_output=False,
perform_initialization=False, perform_initialization=False,
quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = ParallelLinear.row(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False, perform_initialization=False,
quant_config=quant_config,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.attn = PagedAttentionWithRoPE(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
base=self.rope_theta, self.scaling,
rotary_dim=self.head_dim, base=self.rope_theta,
num_kv_heads=self.num_kv_heads) max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
rope_scaling=rope_scaling)
def forward( def forward(
self, self,
@@ -144,21 +158,32 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module): class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = LlamaAttention( self.self_attn = LlamaAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@@ -195,7 +220,11 @@ class LlamaDecoderLayer(nn.Module):
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
@@ -205,7 +234,8 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False) vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -237,16 +267,23 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module):
def __init__(self, config): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.model = LlamaModel(config) self.quant_config = quant_config
self.model = LlamaModel(config, quant_config)
vocab_size = ((config.vocab_size + 63) // 64) * 64 vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size, # NOTE: The LM head is not quantized.
vocab_size, self.lm_head = ParallelLinear.column(config.hidden_size,
bias=False, vocab_size,
gather_output=False, bias=False,
perform_initialization=False) gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
@@ -263,15 +300,28 @@ class LlamaForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [ _column_parallel_layers = []
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight" _row_parallel_layers = ["o_proj", "down_proj"]
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto"): load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size) q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,15 +338,30 @@ class LlamaForCausalLM(nn.Module):
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight = loaded_weight.T
is_attention_weight = False is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs: for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "qkv_proj")] param = state_dict[name.replace(weight_name, "qkv_proj")]
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size * shard_size * tensor_model_parallel_rank:shard_size *
@@ -315,6 +380,9 @@ class LlamaForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "gate_up_proj")] param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2 shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size * shard_size * tensor_model_parallel_rank:shard_size *
@@ -329,6 +397,8 @@ class LlamaForCausalLM(nn.Module):
continue continue
param = state_dict[name] param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name: if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight, load_padded_tensor_parallel_vocab(param, loaded_weight,
@@ -336,6 +406,6 @@ class LlamaForCausalLM(nn.Module):
continue continue
load_tensor_parallel_weights(param, loaded_weight, name, load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights, column_parallel_weights,
self._row_parallel_weights, row_parallel_weights,
tensor_model_parallel_rank) tensor_model_parallel_rank)
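The LLaMA loader above now derives its tensor-parallel weight-name lists from the quantization config, because a quantized checkpoint stores several tensors per linear layer (for example packed weights and scales) rather than a single .weight. A rough sketch of that expansion under the same assumptions as the diff (the suffix names in the comment are only an example of what get_tp_tensor_names might return):

_column_parallel_layers = []
_row_parallel_layers = ["o_proj", "down_proj"]

def expand_tp_weight_names(quant_config):
    if quant_config is None:
        weight_suffixes = ["weight"]
    else:
        # e.g. ["qweight", "qzeros", "scales"] for an AWQ-style checkpoint
        weight_suffixes = quant_config.get_tp_tensor_names()
    column_parallel_weights = [
        f"{layer}.{suffix}" for layer in _column_parallel_layers
        for suffix in weight_suffixes
    ]
    row_parallel_weights = [
        f"{layer}.{suffix}" for layer in _row_parallel_layers
        for suffix in weight_suffixes
    ]
    return column_parallel_weights, row_parallel_weights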

View File

@@ -0,0 +1,404 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mistral import MistralConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MistralMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = ParallelLinear.column(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config)
self.down_proj = ParallelLinear.row(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MistralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = ParallelLinear.column(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config,
)
self.o_proj = ParallelLinear.row(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=max_position,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class MistralDecoderLayer(nn.Module):
def __init__(
self,
config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MistralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
quant_config=quant_config,
sliding_window=config.sliding_window)
self.mlp = MistralMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class MistralModel(nn.Module):
def __init__(
self,
config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([
MistralDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class MistralForCausalLM(nn.Module):
def __init__(
self,
config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = MistralModel(config, quant_config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
# NOTE: The LM head is not quantized.
self.lm_head = ParallelLinear.column(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_layers = []
_row_parallel_layers = ["o_proj", "down_proj"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
]
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight = loaded_weight.T
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
column_parallel_weights,
row_parallel_weights,
tensor_model_parallel_rank)
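A quick sanity check of the QKV sharding arithmetic above, using the default MistralConfig values shown later in this compare (hidden_size=4096, 32 attention heads, 8 KV heads) and an assumed tensor-parallel size of 2; the snippet is illustrative and not part of the diff:

# Illustrative numbers only; tp_size=2 is an assumption.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
tp_size = 2

head_dim = hidden_size // num_attention_heads              # 128
q_proj_shard_size = hidden_size // tp_size                 # 2048 output rows per rank
kv_proj_shard_size = (hidden_size // num_attention_heads *
                      num_key_value_heads // tp_size)      # 128 * 8 // 2 = 512 rows per rank

# Offsets of q/k/v inside each rank's fused qkv_proj weight,
# matching attention_weight_specs above.
offsets = {
    "q_proj": 0,
    "k_proj": q_proj_shard_size,                           # 2048
    "v_proj": q_proj_shard_size + kv_proj_shard_size,      # 2560
}
print(head_dim, q_proj_shard_size, kv_proj_shard_size, offsets)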

View File

@@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].

View File

@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 continue

View File

@@ -8,7 +8,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -76,8 +76,12 @@ class QWenMLP(nn.Module):
 class QWenAttention(nn.Module):
 
-    def __init__(self, hidden_size: int, num_heads: int,
-                 max_position_embeddings: int):
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 max_position_embeddings: int,
+                 rope_theta: float = 10000,
+                 rope_scaling: Optional[Dict[str, Any]] = None):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -109,8 +113,9 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.scaling,
             rotary_dim=self.head_dim,
+            base=rope_theta,
             max_position=max_position_embeddings,
-        )
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -135,14 +140,19 @@ class QWenBlock(nn.Module):
     def __init__(self, config: QWenConfig):
         super().__init__()
-        self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
-                                  config.max_position_embeddings)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        self.attn = QWenAttention(config.hidden_size,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling)
 
-        self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
 
     def forward(
         self,
@@ -181,11 +191,11 @@ class QWenModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.wte = VocabParallelEmbedding(vocab_size,
-                                          config.n_embd,
+                                          config.hidden_size,
                                           perform_initialization=False)
         self.h = nn.ModuleList(
             [QWenBlock(config) for _ in range(config.num_hidden_layers)])
-        self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
     def forward(
         self,
@@ -221,7 +231,7 @@ class QWenLMHeadModel(nn.Module):
         self.transformer = QWenModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ColumnParallelLinear(
-            config.n_embd,
+            config.hidden_size,
             vocab_size,
             bias=False,
             gather_output=False,
@@ -251,13 +261,14 @@ class QWenLMHeadModel(nn.Module):
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
         load_format: str = "auto",
+        revision: Optional[str] = None,
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue

View File

@@ -4,7 +4,7 @@
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from .mappings import ( from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region, gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region,
) )
from .random import get_cuda_rng_tracker
from .utils import ( from .utils import (
divide, divide,
VocabUtility, VocabUtility,
@@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
maybe_copy(attribute) maybe_copy(attribute)
def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork():
init_method(weight)
def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_size, partition_dim,
init_method, stride=1,
return_master_weight=False,
*, params_dtype=None):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=torch.float,
requires_grad=False)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class VocabParallelEmbedding(torch.nn.Module): class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension. """Embedding parallelized in the vocabulary dimension.
@@ -138,8 +83,11 @@ class VocabParallelEmbedding(torch.nn.Module):
init_method=init.xavier_normal_, init_method=init.xavier_normal_,
params_dtype: torch.dtype=None, params_dtype: torch.dtype=None,
use_cpu_initialization: bool=False, use_cpu_initialization: bool=False,
perform_initialization: bool=True): perform_initialization: bool=False):
super(VocabParallelEmbedding, self).__init__() super(VocabParallelEmbedding, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep the input dimensions. # Keep the input dimensions.
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
@@ -162,23 +110,9 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_per_partition = self.vocab_end_index - \ self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index self.vocab_start_index
# Allocate weights and initialize. self.weight = Parameter(torch.empty(
if use_cpu_initialization: self.num_embeddings_per_partition, self.embedding_dim,
self.weight = Parameter(torch.empty( device=torch.cuda.current_device(), dtype=params_dtype))
self.num_embeddings_per_partition, self.embedding_dim,
dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
def forward(self, input_): def forward(self, input_):
if self.tensor_model_parallel_size > 1: if self.tensor_model_parallel_size > 1:
@@ -238,9 +172,12 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add=False, skip_bias_add=False,
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=False,
quant_config=None,
): ):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters # Keep input parameters
self.input_size = input_size self.input_size = input_size
@@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.world_size) self.output_size_per_partition = divide(output_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
@@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
# Parameters. # Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. self.create_weights(params_dtype)
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias: if bias:
if use_cpu_initialization: self.bias = Parameter(torch.empty(
self.bias = Parameter(torch.empty( self.output_size_per_partition,
self.output_size_per_partition, dtype=params_dtype)) device=torch.cuda.current_device(),
else: dtype=params_dtype))
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride) set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
@@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return F.linear(x, self.weight, bias)
def forward(self, input_): def forward(self, input_):
"""Forward of ColumnParallelLinear """Forward of ColumnParallelLinear
@@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_parallel = input_ input_parallel = input_
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias) output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
@@ -359,10 +288,13 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add=False, skip_bias_add=False,
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=False,
reduce_results=True, reduce_results=True,
quant_config=None,
): ):
super(RowParallelLinear, self).__init__() super(RowParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters # Keep input parameters
self.input_size = input_size self.input_size = input_size
@@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size) self.input_size_per_partition = divide(input_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results") "results can lead to incorrect results")
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if bias: if bias:
if use_cpu_initialization: self.bias = Parameter(torch.empty(
self.bias = Parameter(torch.empty(self.output_size, self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.weight_t = self.weight.t()
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
def forward(self, input_): def forward(self, input_):
"""Forward of RowParallelLinear """Forward of RowParallelLinear
@@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
else: else:
input_parallel = scatter_to_tensor_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight) output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.world_size > 1: if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel) output_ = reduce_from_tensor_model_parallel_region(output_parallel)
else: else:

View File

@@ -0,0 +1,20 @@
from typing import Type
from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
_QUANTIZATION_REGISTRY = {
"awq": AWQConfig,
}
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quant_class",
]

View File

@@ -0,0 +1,72 @@
from typing import Any, Dict, List
import torch
from vllm.model_executor.quantization_utils.base import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
@classmethod
def get_name(cls) -> str:
return "awq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Ampere or newer GPUs.
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros"]
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
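A minimal, hedged usage sketch for the class above; the JSON keys below are an assumed example of a quantization config file, and only the keys listed in from_config are consulted:

from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.awq import AWQConfig

# Assumed contents of a 4-bit AWQ quantization config file.
example_config = {"w_bit": 4, "q_group_size": 128, "zero_point": True}

quant_cls = get_quant_class("awq")   # resolves to AWQConfig via the registry
awq_config = quant_cls.from_config(example_config)
print(awq_config)                    # AWQConfig(weight_bits=4, group_size=128, zero_point=True)
print(awq_config.pack_factor)        # 32 // 4 = 8 INT4 values per packed INT32
assert quant_cls is AWQConfig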

View File

@@ -0,0 +1,75 @@
from typing import Any, Dict, List
import torch
class QuantizationConfig:
@classmethod
def get_name(cls) -> str:
"""Name of the quantization method."""
raise NotImplementedError
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
"""
raise NotImplementedError
@classmethod
def get_config_filenames(cls) -> List[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_packed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is packed.
A tensor is considered packed if each element in the tensor is a
packed representation of multiple elements in the original tensor.
For example, an INT32 element in the tensor may represent 8 INT4
elements in the original tensor.
"""
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_transposed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
"""
return any(tag in tensor_name
for tag in cls.get_transposed_tensor_names())
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
raise NotImplementedError
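To make the packed/transposed helpers above concrete, a small sketch using AWQConfig; the parameter names are hypothetical, patterned after an AWQ checkpoint:

from vllm.model_executor.quantization_utils.awq import AWQConfig

names = [
    "model.layers.0.self_attn.qkv_proj.qweight",   # packed (one INT32 holds 8 INT4) and transposed
    "model.layers.0.self_attn.qkv_proj.scales",    # transposed but not packed
    "model.layers.0.input_layernorm.weight",       # neither
]
for name in names:
    print(name, AWQConfig.is_packed(name), AWQConfig.is_transposed(name))
# -> True True / False True / False False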

View File

@@ -4,7 +4,7 @@ import glob
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Iterator, List, Optional, Tuple, Any from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open from safetensors.torch import load_file, save_file, safe_open
@@ -13,6 +13,8 @@ import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.base import QuantizationConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -44,7 +46,7 @@ def _shared_pointers(tensors):
def convert_bin_to_safetensor_file( def convert_bin_to_safetensor_file(
pt_filename: str, pt_filename: str,
sf_filename: str, sf_filename: str,
): ) -> None:
loaded = torch.load(pt_filename, map_location="cpu") loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded: if "state_dict" in loaded:
loaded = loaded["state_dict"] loaded = loaded["state_dict"]
@@ -78,15 +80,55 @@ def convert_bin_to_safetensor_file(
raise RuntimeError(f"The output tensors do not match for key {k}") raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(
quantization: str,
model_name_or_path: str,
cache_dir: Optional[str] = None,
) -> QuantizationConfig:
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.json",
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_cls = get_quant_class(quantization)
quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
]
if len(quant_config_files) == 0:
raise ValueError(f"Cannot find the config file for {quantization}")
if len(quant_config_files) > 1:
raise ValueError(f"Found multiple config files for {quantization}: "
f"{quant_config_files}")
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
config = json.load(f)
return quant_cls.from_config(config)
def prepare_hf_model_weights( def prepare_hf_model_weights(
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_safetensors: bool = False, use_safetensors: bool = False,
fall_back_to_pt: bool = True, fall_back_to_pt: bool = True,
): revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface. # Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
allow_patterns = "*.safetensors" if use_safetensors else "*.bin" if use_safetensors:
allow_patterns = ["*.safetensors"]
else:
# Some quantized models use .pt files for storing the weights.
allow_patterns = ["*.bin", "*.pt"]
if not is_local: if not is_local:
# Use file lock to prevent multiple processes from # Use file lock to prevent multiple processes from
# downloading the same model weights at the same time. # downloading the same model weights at the same time.
@@ -94,10 +136,13 @@ def prepare_hf_model_weights(
hf_folder = snapshot_download(model_name_or_path, hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
cache_dir=cache_dir, cache_dir=cache_dir,
tqdm_class=Disabledtqdm) tqdm_class=Disabledtqdm,
revision=revision)
else: else:
hf_folder = model_name_or_path hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns)) hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if not use_safetensors: if not use_safetensors:
hf_weights_files = [ hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin") x for x in hf_weights_files if not x.endswith("training_args.bin")
@@ -107,7 +152,8 @@ def prepare_hf_model_weights(
return prepare_hf_model_weights(model_name_or_path, return prepare_hf_model_weights(model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
use_safetensors=False, use_safetensors=False,
fall_back_to_pt=False) fall_back_to_pt=False,
revision=revision)
if len(hf_weights_files) == 0: if len(hf_weights_files) == 0:
raise RuntimeError( raise RuntimeError(
@@ -120,6 +166,7 @@ def hf_model_weights_iterator(
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]: ) -> Iterator[Tuple[str, torch.Tensor]]:
use_safetensors = False use_safetensors = False
use_np_cache = False use_np_cache = False
@@ -140,7 +187,8 @@ def hf_model_weights_iterator(
model_name_or_path, model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
fall_back_to_pt=fall_back_to_pt) fall_back_to_pt=fall_back_to_pt,
revision=revision)
if use_np_cache: if use_np_cache:
# Currently np_cache only support *.bin checkpoints # Currently np_cache only support *.bin checkpoints

View File

@@ -1,9 +1,17 @@
 """Sampling parameters for text generation."""
+from enum import IntEnum
+from functools import cached_property
 from typing import List, Optional, Union
 
 _SAMPLING_EPS = 1e-5
 
 
+class SamplingType(IntEnum):
+    GREEDY = 0
+    RANDOM = 1
+    BEAM = 2
+
+
 class SamplingParams:
     """Sampling parameters for text generation.
@@ -45,10 +53,15 @@ class SamplingParams:
             (canonical beam search algorithm).
         stop: List of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
+        stop_token_ids: List of tokens that stop the generation when they are
+            generated. The returned output will contain the stop tokens unless
+            the stop tokens are special tokens.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        skip_special_tokens: Whether to skip special tokens in the output.
+            Defaults to True.
     """
 
     def __init__(
@@ -64,9 +77,11 @@ class SamplingParams:
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
         stop: Union[None, str, List[str]] = None,
+        stop_token_ids: List[int] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -84,9 +99,14 @@ class SamplingParams:
             self.stop = [stop]
         else:
             self.stop = list(stop)
+        if stop_token_ids is None:
+            self.stop_token_ids = []
+        else:
+            self.stop_token_ids = list(stop_token_ids)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.skip_special_tokens = skip_special_tokens
         self._verify_args()
         if self.use_beam_search:
@@ -158,6 +178,14 @@ class SamplingParams:
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using greedy sampling.")
 
+    @cached_property
+    def sampling_type(self) -> SamplingType:
+        if self.use_beam_search:
+            return SamplingType.BEAM
+        if self.temperature < _SAMPLING_EPS:
+            return SamplingType.GREEDY
+        return SamplingType.RANDOM
+
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
                 f"best_of={self.best_of}, "
@@ -172,4 +200,5 @@ class SamplingParams:
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"skip_special_tokens={self.skip_special_tokens})")
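A hedged usage sketch for the two new parameters; the stop token id 32000 is a made-up example of an added special token:

from vllm import SamplingParams

params = SamplingParams(
    temperature=0.8,
    max_tokens=128,
    stop=["\n\n"],
    stop_token_ids=[32000],      # assumed id of a custom stop token
    skip_special_tokens=False,   # keep special tokens in the returned text
)
print(params.sampling_type)      # SamplingType.RANDOM, via the new cached_property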

View File

@@ -114,7 +114,6 @@ class Sequence:
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
-        self.output_tokens: List[str] = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -122,6 +121,12 @@ class Sequence:
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
 
+        # Used for incremental detokenization
+        self.prefix_offset = 0
+        self.read_offset = 0
+        # Input + output tokens
+        self.tokens: Optional[List[str]] = None
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
@@ -245,8 +250,8 @@ class SequenceGroup:
             # generation stage, we will have `best_of` sequences running.
             return self.sampling_params.best_of
         # At sampling stages, return the number of actual sequences
-        # running.
-        return self.num_seqs(status=SequenceStatus.RUNNING)
+        # that are not finished yet.
+        return self.num_unfinished_seqs()
 
     def get_seqs(
         self,
@@ -259,12 +264,23 @@ class SequenceGroup:
             seq for seq in self.seqs_dict.values() if seq.status == status
         ]
 
+    def get_unfinished_seqs(self) -> List[Sequence]:
+        return [
+            seq for seq in self.seqs_dict.values() if not seq.is_finished()
+        ]
+
     def get_finished_seqs(self) -> List[Sequence]:
         return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
 
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         return len(self.get_seqs(status))
 
+    def num_unfinished_seqs(self) -> int:
+        return len(self.get_unfinished_seqs())
+
+    def num_finished_seqs(self) -> int:
+        return len(self.get_finished_seqs())
+
     def find(self, seq_id: int) -> Sequence:
         if seq_id not in self.seqs_dict:
             raise ValueError(f"Sequence {seq_id} not found.")
@@ -345,7 +361,7 @@ class SequenceOutputs:
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
-            return NotImplementedError()
+            raise NotImplementedError()
         return (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token
                 and self.logprobs == other.logprobs)

View File

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from transformers import AutoConfig, PretrainedConfig
 
 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
@@ -12,10 +14,21 @@ _CONFIG_REGISTRY = {
 }
 
 
-def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+def get_config(model: str,
+               trust_remote_code: bool,
+               revision: Optional[str] = None) -> PretrainedConfig:
+    # NOTE: Because the Mistral model in HF hub does not have
+    # `configuration_mistral.py`, we cannot use `AutoConfig` to load the
+    # config. Instead, we use `MistralConfig` directly.
+    # NOTE: This is a hack. This does not work for local models.
+    # FIXME: Remove this once the Mistral model is available in the stable
+    # version of HF transformers.
+    if "mistral" in model.lower():
+        return MistralConfig.from_pretrained(model, revision=revision)
+
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code)
+            model, trust_remote_code=trust_remote_code, revision=revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -29,5 +42,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
         raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model)
+        config = config_class.from_pretrained(model, revision=revision)
     return config

View File

@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.mistral import MistralConfig
 
 __all__ = [
     "MPTConfig",
@@ -13,4 +14,5 @@ __all__ = [
     "AquilaConfig",
     "QWenConfig",
     "RWConfig",
+    "MistralConfig",
 ]

View File

@@ -0,0 +1,66 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mistral-7B-v0.1 configuration"""
from transformers.configuration_utils import PretrainedConfig
class MistralConfig(PretrainedConfig):
model_type = "mistral"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
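The defaults above imply a very long nominal context with a much shorter attention window; a tiny illustrative check, not part of the diff:

from vllm.transformers_utils.configs.mistral import MistralConfig

config = MistralConfig()
print(config.max_position_embeddings)   # 131072 (4096 * 32)
print(config.sliding_window)            # 4096-token attention window
print(config.num_key_value_heads)       # 8 (grouped-query attention)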

View File

@@ -7,65 +7,54 @@ from transformers import PretrainedConfig
 class QWenConfig(PretrainedConfig):
     model_type = "qwen"
     keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "hidden_size": "n_embd",
+        "num_attention_heads": "n_head",
+        "max_position_embeddings": "n_positions",
+        "num_hidden_layers": "n_layer",
+    }
 
     def __init__(
         self,
-        vocab_size=151851,
-        n_embd=4096,
-        n_layer=32,
-        n_head=32,
-        n_inner=None,
-        embd_pdrop=0.0,
-        attn_pdrop=0.0,
-        layer_norm_epsilon=1e-5,
+        vocab_size=151936,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        emb_dropout_prob=0.0,
+        attn_dropout_prob=0.0,
+        layer_norm_epsilon=1e-6,
         initializer_range=0.02,
+        max_position_embeddings=8192,
         scale_attn_weights=True,
         use_cache=True,
-        eos_token_id=151643,
-        apply_residual_connection_post_layernorm=False,
-        bf16=True,
+        bf16=False,
+        fp16=False,
+        fp32=False,
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
-        use_dynamic_ntk=False,
-        use_logn_attn=False,
-        use_flash_attn=True,
-        ffn_hidden_size=22016,
+        use_dynamic_ntk=True,
+        use_logn_attn=True,
+        use_flash_attn="auto",
+        intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
         **kwargs,
     ):
-        self.eos_token_id = eos_token_id
-        super().__init__(eos_token_id=eos_token_id,
-                         tie_word_embeddings=tie_word_embeddings,
-                         **kwargs)
         self.vocab_size = vocab_size
-        self.n_embd = n_embd
-        self.n_layer = n_layer
-        self.n_head = n_head
-        self.n_inner = n_inner
-        self.embd_pdrop = embd_pdrop
-        self.attn_pdrop = attn_pdrop
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.emb_dropout_prob = emb_dropout_prob
+        self.attn_dropout_prob = attn_dropout_prob
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
-        self.apply_residual_connection_post_layernorm = (
-            apply_residual_connection_post_layernorm)
+        self.max_position_embeddings = max_position_embeddings
         self.bf16 = bf16
+        self.fp16 = fp16
+        self.fp32 = fp32
         self.kv_channels = kv_channels
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
-        self.ffn_hidden_size = ffn_hidden_size
         self.no_bias = no_bias
-        self.tie_word_embeddings = tie_word_embeddings
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

View File

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
@@ -28,8 +28,8 @@ def get_tokenizer(
     if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
             and tokenizer_name != _FAST_LLAMA_TOKENIZER):
         logger.info(
-            "For some LLaMA-based models, initializing the fast tokenizer may "
-            "take a long time. To eliminate the initialization time, consider "
+            "For some LLaMA V1 models, initializing the fast tokenizer may "
+            "take a long time. To reduce the initialization time, consider "
             f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
             "tokenizer.")
     try:
@@ -41,9 +41,9 @@ def get_tokenizer(
     except TypeError as e:
         # The LLaMA tokenizer causes a protobuf error in some environments.
         err_msg = (
-            "Failed to load the tokenizer. If you are using a LLaMA-based "
-            f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-            "tokenizer.")
+            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
+            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
+            "original tokenizer.")
         raise RuntimeError(err_msg) from e
     except ValueError as e:
         # If the error pertains to the tokenizer class not existing or not
@@ -67,33 +67,11 @@ def get_tokenizer(
return tokenizer return tokenizer
def detokenize_incrementally( def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prev_output_tokens: List[str], output_tokens: List[str],
new_token_id: int,
skip_special_tokens: bool, skip_special_tokens: bool,
) -> Tuple[str, str]: ) -> str:
"""Detokenizes the new token in conjunction with the previous output tokens.
NOTE: This function does not update prev_output_tokens.
Returns:
new_token: The new token as a string.
output_text: The new output text as a string.
"""
if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
return None, prev_output_tokens
new_token = tokenizer.convert_ids_to_tokens(
new_token_id, skip_special_tokens=skip_special_tokens)
output_tokens = prev_output_tokens + [new_token]
# Convert the tokens to a string.
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
# then we can directly use `convert_tokens_to_string`.
if not getattr(tokenizer, "added_tokens_encoder", {}):
output_text = tokenizer.convert_tokens_to_string(output_tokens)
return new_token, output_text
# Adapted from # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over # NOTE(woosuk): The following code is slow because it runs a for loop over
@@ -115,5 +93,61 @@ def detokenize_incrementally(
if current_sub_text: if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text) sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text) sub_texts.append(sub_text)
output_text = " ".join(sub_texts) return " ".join(sub_texts)
return new_token, output_text
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int = 0,
read_offset: int = 0,
skip_special_tokens: bool = False,
) -> Tuple[List[str], str, int, int]:
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
if prev_tokens is None:
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids, skip_special_tokens=skip_special_tokens)
output_tokens = new_tokens
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token.
prefix_offset = max(len(output_tokens) - 6, 0)
read_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
output_tokens = prev_tokens + new_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if not getattr(tokenizer, "added_tokens_encoder", {}):
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset])
new_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:])
else:
prefix_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens)
new_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens)
if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
new_text = new_text[len(prefix_text):]
return new_tokens, new_text, read_offset, len(output_tokens)
else:
return new_tokens, "", prefix_offset, read_offset
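A hedged sketch of how a caller can drive the new incremental API, mirroring the prefix_offset / read_offset / tokens fields added to Sequence earlier in this compare; the tokenizer name is just an example:

from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_input_ids = tokenizer("Hello").input_ids
prev_tokens, prefix_offset, read_offset = None, 0, 0
text = ""

# Feed one "generated" token id at a time, as the engine would.
for new_id in tokenizer(" world, how are you?").input_ids:
    all_input_ids.append(new_id)
    new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer,
        all_input_ids,
        prev_tokens,
        prefix_offset,
        read_offset,
        skip_special_tokens=True,
    )
    prev_tokens = new_tokens if prev_tokens is None else prev_tokens + new_tokens
    text += new_text

print(text)  # incrementally detokenized continuation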

View File

@@ -1,10 +1,12 @@
 import enum
-from platform import uname
 import uuid
+from platform import uname
 
 import psutil
 import torch
 
+from vllm import cuda_utils
+
 
 class Device(enum.Enum):
     GPU = enum.auto()
@@ -25,6 +27,15 @@ class Counter:
         self.counter = 0
 
 
+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97  # pylint: disable=invalid-name
+    max_shared_mem = cuda_utils.get_device_attribute(
+        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
+    return int(max_shared_mem)
+
+
 def get_gpu_memory(gpu: int = 0) -> int:
     """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory

View File

@@ -33,7 +33,7 @@ class CacheEngine:
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.num_heads = model_config.get_num_heads(parallel_config)
+        self.num_heads = model_config.get_num_kv_heads(parallel_config)
         self.dtype = model_config.dtype
 
         self.block_size = cache_config.block_size
@@ -146,7 +146,7 @@ class CacheEngine:
         parallel_config: ParallelConfig,
     ) -> int:
         head_size = model_config.get_head_size()
-        num_heads = model_config.get_num_heads(parallel_config)
+        num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
 
         key_cache_block = block_size * num_heads * head_size

View File

@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
 
 
 class Worker:
@@ -42,6 +42,7 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.block_size = None
+        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@@ -136,6 +137,15 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+        self.sliding_window = cache_config.sliding_window
+
+        if self.sliding_window is None:
+            max_seq_len = self.scheduler_config.max_model_len
+        else:
+            max_seq_len = min(self.scheduler_config.max_model_len,
+                              self.sliding_window)
+        _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
@@ -207,10 +217,11 @@ class Worker:
             context_len = seq_data.get_len()
             position = context_len - 1
+            if self.sliding_window is not None:
+                context_len = min(context_len, self.sliding_window)
             input_positions.append(position)
 
             block_table = seq_group_metadata.block_tables[seq_id]
-            generation_block_tables.append(block_table)
 
             max_context_len = max(max_context_len, context_len)
             max_num_blocks_per_seq = max(max_num_blocks_per_seq,
@@ -222,21 +233,37 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
+                if self.sliding_window is not None:
+                    sliding_window_blocks = (self.sliding_window //
+                                             self.block_size)
+                    block_table = block_table[-sliding_window_blocks:]
+                generation_block_tables.append(block_table)
+
         # Optimization: Pad the input length to be a multiple of 8.
         # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
         input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
         input_positions = _pad_to_alignment(input_positions, multiple_of=8)
 
         # Convert to tensors.
-        tokens_tensor = torch.cuda.LongTensor(input_tokens)
-        positions_tensor = torch.cuda.LongTensor(input_positions)
-        slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
-        context_lens_tensor = torch.cuda.IntTensor(context_lens)
+        tokens_tensor = torch.tensor(input_tokens,
+                                     dtype=torch.long,
+                                     device="cuda")
+        positions_tensor = torch.tensor(input_positions,
+                                        dtype=torch.long,
+                                        device="cuda")
+        slot_mapping_tensor = torch.tensor(slot_mapping,
+                                           dtype=torch.int,
+                                           device="cuda")
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device="cuda")
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables
         ]
-        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
+        block_tables_tensor = torch.tensor(padded_block_tables,
+                                           dtype=torch.int,
+                                           device="cuda")
 
         seq_data: Dict[int, SequenceData] = {}
         for seq_group_metadata in seq_group_metadata_list:
@@ -250,6 +277,7 @@ class Worker:
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
max_context_len=max_context_len, max_context_len=max_context_len,
block_tables=block_tables_tensor, block_tables=block_tables_tensor,
sliding_window=self.sliding_window,
) )
return tokens_tensor, positions_tensor, input_metadata return tokens_tensor, positions_tensor, input_metadata
@@ -337,3 +365,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
def _pad_to_max(x: List[int], max_len: int) -> List[int]: def _pad_to_max(x: List[int], max_len: int) -> List[int]:
return x + [0] * (max_len - len(x)) return x + [0] * (max_len - len(x))
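For a rough sense of what the _check_if_can_support_max_seq_len check added below allows, a worked example; the shared-memory figures are assumptions about typical GPUs (roughly 64 KB opt-in per block on Turing, 96 KB on Volta), not values from this diff:

float32_bytes = 4
block_size = 16

for max_shared_mem in (64 * 1024, 96 * 1024):
    # Largest padded_max_seq_len satisfying
    # padded_max_seq_len * float32_bytes <= max_shared_mem,
    # rounded down to a multiple of block_size.
    max_supported = (max_shared_mem // float32_bytes) // block_size * block_size
    print(max_shared_mem, "->", max_supported)
# 65536 -> 16384, 98304 -> 24576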
def _check_if_can_support_max_seq_len(max_seq_len: int,
block_size: int) -> None:
# Follows the logic in
# attention_kernels.cu::single_query_cached_kv_attention_launcher
max_shared_mem = get_max_shared_memory_bytes()
float32_bytes = torch.finfo(torch.float).bits // 8
padded_max_seq_len = (
(max_seq_len + block_size - 1) // block_size) * block_size
# padded_max_seq_len + extra buffer
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
if padded_max_seq_len * float32_bytes > max_shared_mem:
raise RuntimeError(
f"vLLM cannot currently support max_model_len={max_seq_len} "
f"with block_size={block_size} on GPU with compute "
f"capability {torch.cuda.get_device_capability()} "
f"(required shared memory {required_shared_mem} > "
f"available shared memory {max_shared_mem}). "
"This will be fixed in a future release.")