Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2fb71ec9f | ||
|
|
f936657eb6 | ||
|
|
6f88f762bf | ||
|
|
202351d5bf | ||
|
|
2e8e49fce3 | ||
|
|
a8e98aee0c | ||
|
|
bb1ba58f06 | ||
|
|
7bedab5748 | ||
|
|
20f7cc4cde | ||
|
|
649aa730c5 | ||
|
|
a19bc5c628 | ||
|
|
28e616c4e3 | ||
|
|
30e775281d | ||
|
|
21877b0d75 | ||
|
|
cf5cb1e33e | ||
|
|
03ffd0a022 | ||
|
|
a425bd9a9a | ||
|
|
bbbf86565f | ||
|
|
9f6be8692e | ||
|
|
f187877945 | ||
|
|
947b794146 | ||
|
|
8d926e91f1 | ||
|
|
4ee52bb169 | ||
|
|
7d7e3b78a3 | ||
|
|
f98b745a81 | ||
|
|
2d1e86f1b1 | ||
|
|
1ac4ccf73c | ||
|
|
2ac4d5e2bf | ||
|
|
3302f0aef3 | ||
|
|
6f2dd6c37e | ||
|
|
bc0644574c | ||
|
|
400b8289f7 | ||
|
|
c1026311b5 | ||
|
|
2b1c116b5a | ||
|
|
cc796b1358 | ||
|
|
f029ef94d7 | ||
|
|
95592fa00a | ||
|
|
fbe66e1d0b | ||
|
|
90979c38f8 | ||
|
|
e21d7687a9 | ||
|
|
ff36139ffc | ||
|
|
e3e79e9e8a | ||
|
|
b9fe4616f9 | ||
|
|
64ca424e75 | ||
|
|
b5f93d0631 | ||
|
|
a58936966f | ||
|
|
dd54a4b026 | ||
|
|
eda1a7cad3 | ||
|
|
f04908cae7 | ||
|
|
ab019eea75 | ||
|
|
9841d48a10 | ||
|
|
3272d7a0b7 | ||
|
|
0bb1e885a0 | ||
|
|
d6545ad22e | ||
|
|
90eb3f43ca | ||
|
|
e67b4f2c2a | ||
|
|
d6770d1f23 | ||
|
|
b9cecc2635 | ||
|
|
898285c9bf | ||
|
|
a62de9ecfd | ||
|
|
4042d192f5 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -173,3 +173,7 @@ cython_debug/
|
|||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
_build/
|
_build/
|
||||||
|
|
||||||
|
# vim swap files
|
||||||
|
*.swo
|
||||||
|
*.swp
|
||||||
|
|||||||
32
README.md
32
README.md
@@ -10,13 +10,24 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
</h3>
|
</h3>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
|
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
**The First vLLM Bay Area Meetup (Oct 5th 6pm-8pm PT)**
|
||||||
|
|
||||||
|
We are excited to invite you to the first vLLM meetup!
|
||||||
|
The vLLM team will share recent updates and roadmap.
|
||||||
|
We will also have vLLM users and contributors coming up to the stage to share their experiences.
|
||||||
|
Please register [here](https://lu.ma/first-vllm-meetup) and join us!
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
*Latest News* 🔥
|
*Latest News* 🔥
|
||||||
|
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
||||||
|
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
||||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||||
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
||||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
||||||
@@ -35,13 +46,13 @@ vLLM is fast with:
|
|||||||
|
|
||||||
vLLM is flexible and easy to use with:
|
vLLM is flexible and easy to use with:
|
||||||
|
|
||||||
- Seamless integration with popular HuggingFace models
|
- Seamless integration with popular Hugging Face models
|
||||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||||
- Tensor parallelism support for distributed inference
|
- Tensor parallelism support for distributed inference
|
||||||
- Streaming outputs
|
- Streaming outputs
|
||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
|
|
||||||
vLLM seamlessly supports many Huggingface models, including the following architectures:
|
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||||
|
|
||||||
- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||||
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
||||||
@@ -53,6 +64,7 @@ vLLM seamlessly supports many Huggingface models, including the following archit
|
|||||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||||
|
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||||
@@ -72,7 +84,7 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
|
|||||||
|
|
||||||
## Performance
|
## Performance
|
||||||
|
|
||||||
vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
|
vLLM outperforms Hugging Face Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
|
||||||
For details, check out our [blog post](https://vllm.ai).
|
For details, check out our [blog post](https://vllm.ai).
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -104,3 +116,15 @@ For details, check out our [blog post](https://vllm.ai).
|
|||||||
|
|
||||||
We welcome and value any contributions and collaborations.
|
We welcome and value any contributions and collaborations.
|
||||||
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{kwon2023efficient,
|
||||||
|
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
||||||
|
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
|
||||||
|
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
|
|||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
tokenizer=args.tokenizer,
|
tokenizer=args.tokenizer,
|
||||||
|
quantization=args.quantization,
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
tensor_parallel_size=args.tensor_parallel_size,
|
||||||
max_num_seqs=args.batch_size,
|
max_num_seqs=args.batch_size,
|
||||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||||
@@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Benchmark the latency of processing a single batch of '
|
description='Benchmark the latency of processing a single batch of '
|
||||||
'requests till completion.')
|
'requests till completion.')
|
||||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||||
parser.add_argument('--tokenizer', type=str, default=None)
|
parser.add_argument('--tokenizer', type=str, default=None)
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None)
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
parser.add_argument('--input-len', type=int, default=32)
|
||||||
parser.add_argument('--output-len', type=int, default=128)
|
parser.add_argument('--output-len', type=int, default=128)
|
||||||
parser.add_argument('--batch-size', type=int, default=8)
|
parser.add_argument('--batch-size', type=int, default=8)
|
||||||
parser.add_argument('--n', type=int, default=1,
|
parser.add_argument('--n',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
help='Number of generated sequences per prompt.')
|
help='Number of generated sequences per prompt.')
|
||||||
parser.add_argument('--use-beam-search', action='store_true')
|
parser.add_argument('--use-beam-search', action='store_true')
|
||||||
parser.add_argument('--num-iters', type=int, default=3,
|
parser.add_argument('--num-iters',
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
help='Number of iterations to run.')
|
help='Number of iterations to run.')
|
||||||
parser.add_argument('--trust-remote-code', action='store_true',
|
parser.add_argument('--trust-remote-code',
|
||||||
|
action='store_true',
|
||||||
help='trust remote code from huggingface')
|
help='trust remote code from huggingface')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||||
@@ -22,15 +22,10 @@ def sample_requests(
|
|||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
data for data in dataset
|
|
||||||
if len(data["conversations"]) >= 2
|
|
||||||
]
|
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [(data["conversations"][0]["value"],
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
for data in dataset
|
|
||||||
]
|
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Tokenize the prompts and completions.
|
||||||
prompts = [prompt for prompt, _ in dataset]
|
prompts = [prompt for prompt, _ in dataset]
|
||||||
@@ -63,6 +58,7 @@ def run_vllm(
|
|||||||
requests: List[Tuple[str, int, int]],
|
requests: List[Tuple[str, int, int]],
|
||||||
model: str,
|
model: str,
|
||||||
tokenizer: str,
|
tokenizer: str,
|
||||||
|
quantization: Optional[str],
|
||||||
tensor_parallel_size: int,
|
tensor_parallel_size: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
n: int,
|
n: int,
|
||||||
@@ -72,6 +68,7 @@ def run_vllm(
|
|||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
quantization=quantization,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
@@ -111,8 +108,8 @@ def run_hf(
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
) -> float:
|
) -> float:
|
||||||
assert not use_beam_search
|
assert not use_beam_search
|
||||||
llm = AutoModelForCausalLM.from_pretrained(model,
|
llm = AutoModelForCausalLM.from_pretrained(
|
||||||
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||||
if llm.config.model_type == "llama":
|
if llm.config.model_type == "llama":
|
||||||
# To enable padding in the HF backend.
|
# To enable padding in the HF backend.
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@@ -132,13 +129,14 @@ def run_hf(
|
|||||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||||
# Check if we can add more requests to the batch.
|
# Check if we can add more requests to the batch.
|
||||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||||
if (max(max_prompt_len, next_prompt_len) + max(
|
if (max(max_prompt_len, next_prompt_len) +
|
||||||
max_output_len, next_output_len)) <= 2048:
|
max(max_output_len, next_output_len)) <= 2048:
|
||||||
# We can add more requests to the batch.
|
# We can add more requests to the batch.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Generate the sequences.
|
# Generate the sequences.
|
||||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
input_ids = tokenizer(batch, return_tensors="pt",
|
||||||
|
padding=True).input_ids
|
||||||
llm_outputs = llm.generate(
|
llm_outputs = llm.generate(
|
||||||
input_ids=input_ids.cuda(),
|
input_ids=input_ids.cuda(),
|
||||||
do_sample=not use_beam_search,
|
do_sample=not use_beam_search,
|
||||||
@@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
|
|
||||||
# Sample the requests.
|
# Sample the requests.
|
||||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
tokenizer = get_tokenizer(args.tokenizer,
|
||||||
|
trust_remote_code=args.trust_remote_code)
|
||||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(
|
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
args.quantization, args.tensor_parallel_size,
|
||||||
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
|
args.seed, args.n, args.use_beam_search,
|
||||||
|
args.trust_remote_code)
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
requests, args.model, tokenizer, args.n, args.use_beam_search,
|
args.use_beam_search, args.hf_max_batch_size,
|
||||||
args.hf_max_batch_size, args.trust_remote_code)
|
args.trust_remote_code)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
total_num_tokens = sum(
|
total_num_tokens = sum(prompt_len + output_len
|
||||||
prompt_len + output_len
|
for _, prompt_len, output_len in requests)
|
||||||
for _, prompt_len, output_len in requests
|
|
||||||
)
|
|
||||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||||
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
|
parser.add_argument("--backend",
|
||||||
|
type=str,
|
||||||
|
choices=["vllm", "hf"],
|
||||||
default="vllm")
|
default="vllm")
|
||||||
parser.add_argument("--dataset", type=str, required=True,
|
parser.add_argument("--dataset",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
help="Path to the dataset.")
|
help="Path to the dataset.")
|
||||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
parser.add_argument("--tokenizer", type=str, default=None)
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None)
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||||
parser.add_argument("--n", type=int, default=1,
|
parser.add_argument("--n",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
help="Number of generated sequences per prompt.")
|
help="Number of generated sequences per prompt.")
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
parser.add_argument("--num-prompts",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
help="Number of prompts to process.")
|
help="Number of prompts to process.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
parser.add_argument("--hf-max-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
help="Maximum batch size for HF backend.")
|
help="Maximum batch size for HF backend.")
|
||||||
parser.add_argument('--trust-remote-code',
|
parser.add_argument('--trust-remote-code',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@@ -215,6 +227,8 @@ if __name__ == "__main__":
|
|||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
if args.hf_max_batch_size is None:
|
if args.hf_max_batch_size is None:
|
||||||
raise ValueError("HF max batch size is required for HF backend.")
|
raise ValueError("HF max batch size is required for HF backend.")
|
||||||
|
if args.quantization is not None:
|
||||||
|
raise ValueError("Quantization is only for vLLM backend.")
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
|
||||||
|
|||||||
@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
||||||
|
cudaFuncSetAttribute( \
|
||||||
|
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||||
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||||
<<<grid, block, shared_mem_size, stream>>>( \
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
out_ptr, \
|
out_ptr, \
|
||||||
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
||||||
int logits_size = padded_max_context_len * sizeof(float);
|
int logits_size = padded_max_context_len * sizeof(float);
|
||||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||||
|
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||||
|
// Keep that in sync with the logic here!
|
||||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||||
|
|
||||||
dim3 grid(num_heads, num_seqs);
|
dim3 grid(num_heads, num_seqs);
|
||||||
|
|||||||
13
csrc/cuda_utils.cpp
Normal file
13
csrc/cuda_utils.cpp
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
int get_device_attribute(
|
||||||
|
int attribute,
|
||||||
|
int device_id);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"get_device_attribute",
|
||||||
|
&get_device_attribute,
|
||||||
|
"Gets the specified device attribute.");
|
||||||
|
}
|
||||||
|
|
||||||
14
csrc/cuda_utils_kernels.cu
Normal file
14
csrc/cuda_utils_kernels.cu
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
int get_device_attribute(
|
||||||
|
int attribute,
|
||||||
|
int device_id)
|
||||||
|
{
|
||||||
|
int device, value;
|
||||||
|
if (device_id < 0) {
|
||||||
|
cudaGetDevice(&device);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
device = device_id;
|
||||||
|
}
|
||||||
|
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
15
csrc/quantization.cpp
Normal file
15
csrc/quantization.cpp
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"awq_gemm",
|
||||||
|
&awq_gemm,
|
||||||
|
"Quantized GEMM for AWQ");
|
||||||
|
}
|
||||||
87
csrc/quantization/awq/dequantize.cuh
Normal file
87
csrc/quantization/awq/dequantize.cuh
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace awq {
|
||||||
|
|
||||||
|
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
uint4 result;
|
||||||
|
|
||||||
|
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||||
|
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||||
|
|
||||||
|
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||||
|
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||||
|
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||||
|
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||||
|
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||||
|
|
||||||
|
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||||
|
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||||
|
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||||
|
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||||
|
|
||||||
|
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||||
|
// immediately before required.
|
||||||
|
const uint32_t top_i4s = i4s >> 8;
|
||||||
|
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[0])
|
||||||
|
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[1])
|
||||||
|
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[2])
|
||||||
|
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[3])
|
||||||
|
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
|
||||||
|
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||||
|
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||||
|
|
||||||
|
// This is the half2 {1032, 1032} represented as an integer.
|
||||||
|
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||||
|
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||||
|
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||||
|
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||||
|
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||||
|
// This is the half2 {-72, -72} represented as an integer.
|
||||||
|
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||||
|
// Haotian: Let's use {-64, -64}.
|
||||||
|
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||||
|
|
||||||
|
// Finally, we construct the output numbers.
|
||||||
|
// Convert elt_01
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_23
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
// Convert elt_45
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_67
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace awq
|
||||||
|
} // namespace vllm
|
||||||
491
csrc/quantization/awq/gemm_kernels.cu
Normal file
491
csrc/quantization/awq/gemm_kernels.cu
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include "dequantize.cuh"
|
||||||
|
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace awq {
|
||||||
|
|
||||||
|
// Pack two half values.
|
||||||
|
static inline __device__ __host__ unsigned
|
||||||
|
__pack_half2(const half x, const half y) {
|
||||||
|
unsigned v0 = *((unsigned short *)&x);
|
||||||
|
unsigned v1 = *((unsigned short *)&y);
|
||||||
|
return (v1 << 16) | v0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fused 4-bit weight dequantization + half-precision GEMM.
// Each thread block computes one 16 x 128 tile of C for one split-k slice.
//
// Parameters:
//   G               quantization group size along IC; the zeros / scaling-factor
//                   row used for k-step `k` is `k * 32 / G`.
//   split_k_iters   number of split-k slices; slice `z` handles the k-steps
//                   z, z + split_k_iters, z + 2 * split_k_iters, ...
//   A               activations, indexed as [M, IC] (fp16).
//   B               packed 4-bit weights, indexed as [IC, OC / 8] (int32, 8 weights per word).
//   scaling_factors indexed as [IC / G, OC] (fp16).
//   zeros           packed 4-bit zero points, indexed as [IC / G, OC / 8] (int32).
//   M, IC, OC       problem sizes (rows of A, reduction dim, output channels).
//   C               partial outputs, indexed as [split_k_iters, M, OC] (fp16); the
//                   caller is expected to reduce over the first dimension.
//
// Launch layout: 1-D grid; blockIdx.x is decomposed below into an output-tile
// index and a split-k slice index. threadIdx.y is used as the warp id
// (__launch_bounds__(64) => 2 warps of 32 threads).
//
// When compiled for an architecture below sm_80 the body is just assert(false).
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Zero addend for the fma below, so that fma(x, scale, 0) == x * scale.
  static constexpr uint32_t ZERO = 0x0;
  float C_warp[32];
  // Tiles staged in shared memory; each row carries 8 extra halves of padding
  // (presumably to avoid shared-memory bank conflicts -- TODO confirm).
  __shared__ half A_shared[16 * (32 + 8)];
  __shared__ half B_shared[32 * (128 + 8)];

  // Number of 128-wide column tiles.
  int j_factors1 = ((OC + 128 - 1) / 128);
  // Output-tile index (row tile * j_factors1 + column tile).
  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  // Split-k slice index.
  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);

  half A_shared_warp[8];
  half B_shared_warp[32];
  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
    for (int i = 0; i < 8; ++i) {
      C_warp[(j_0_4_init * 8) + i] = 0.0f;
    }
  }

  static constexpr int row_stride_warp = 32 * 8 / 32;
  // Row stride in shared memory: (NWARPS * 32 * 8 / cta_N).
  static constexpr int row_stride = 2 * 32 * 8 / 128;
  // True when the row of A this thread loads is inside the matrix (threadIdx.y is warp_id).
  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;

  // Per-thread global source of the 8 halves of A loaded on every k-step.
  half* A_ptr = A
                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
                + (((int)threadIdx.x) % (32 / 8)) * 8;

  // Per-thread global source of the packed weight words.
  int* B_ptr = B
               + ((int)threadIdx.y) * (OC / 8) * 2
               + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
               + (((int)blockIdx_y) % j_factors1) * (128 / 8)
               + (((int)threadIdx.x) % (128 / 8)) * 1;

  half* A_shared_ptr = A_shared
                       + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
                       + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
                       + (((int)threadIdx.x) % (32 / 8)) * 8;

  half* B_shared_ptr = B_shared
                       + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
                       + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
                       + (((int)threadIdx.x) % (128 / 8)) * 8;

  int* zeros_ptr = zeros
                   + (((int)blockIdx_y) % j_factors1) * (128 / 8)
                   + ((int)threadIdx.x) % (128 / 8);

  half* scaling_factors_ptr = scaling_factors
                              + (((int)blockIdx_y) % j_factors1) * (128)
                              + (((int)threadIdx.x) % (128 / 8)) * 8;

  half* C_ptr = C
                + blockIdx_z * M * OC  // blockIdx_z -> split_k dim
                + (((int)blockIdx_y) % j_factors1) * 128
                + ((int)threadIdx.y) * 64
                + (((int)threadIdx.x) % 4) * 2;

  // Number of 32-wide k-steps handled by this split-k slice; the last step is
  // dropped when it would start at or beyond IC.
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;

  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    // Make sure the previous iteration has finished reading the shared tiles.
    __syncthreads();
    if (ld_A_flag) {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    } else {
      // Rows past M contribute zeros.
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    // Zero points and scales of the quantization group this k-step belongs to.
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
      // B tile in shared memory: 32 x 136 (128 + 8) fp16.
      // Each thread: read 32 bits -> convert to 8 x fp16 (a uint4)
      // -> subtract zero and multiply by scale -> write the uint4 back.
      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);

      // deq = (q - zero) * scale
      // TODO (Haotian): can save 4 assembly instructions if formulated as
      // deq = q * scale - zero * scale.
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

      // write back
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
    }
    // Both shared tiles must be complete before any warp reads them.
    __syncthreads();

    // The 32-wide k-step is consumed as two 16-wide mma steps.
    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
      // Load this warp's A fragment from shared memory.
      {
        unsigned int addr;
        __asm__ __volatile__(
          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
          : "=r"(addr)
          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
        );

        __asm__ __volatile__(
          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
          "{%0, %1, %2, %3}, [%4];\n"
          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
          : "r"(addr)
        );
      }

      // Load this warp's B fragments (transposed on load).
      for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
        {
          unsigned int addr;
          __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
            : "=r"(addr)
            : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
          );
          __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
            : "r"(addr)
          );
        }
      }

      // Two m16n8k16 tensor-core mma ops per 16-column group, accumulating in fp32.
      for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
        }

        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }
      }
    }
  }

  // Write the fp32 accumulators back to C as fp16, skipping rows past M.
  // The part of the row index that does not depend on the loop variables is
  // computed once.
  const int row_base = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4;
  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset = row_base + (local_id % 4) / 2 * 8;
      if (row_offset < M) {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
      }
    }
  }
#endif
}
|
||||||
|
|
||||||
|
|
||||||
|
// Fused 4-bit weight dequantization + half-precision GEMM.
// Each thread block computes one 16 x 64 tile of C for one split-k slice.
// This is the 64-column counterpart of gemm_forward_4bit_cuda_m16n128k32.
//
// Parameters:
//   G               quantization group size along IC; the zeros / scaling-factor
//                   row used for k-step `k` is `k * 32 / G`.
//   split_k_iters   number of split-k slices; slice `z` handles the k-steps
//                   z, z + split_k_iters, z + 2 * split_k_iters, ...
//   A               activations, indexed as [M, IC] (fp16).
//   B               packed 4-bit weights, indexed as [IC, OC / 8] (int32, 8 weights per word).
//   scaling_factors indexed as [IC / G, OC] (fp16).
//   zeros           packed 4-bit zero points, indexed as [IC / G, OC / 8] (int32).
//   M, IC, OC       problem sizes (rows of A, reduction dim, output channels).
//   C               partial outputs, indexed as [split_k_iters, M, OC] (fp16); the
//                   caller is expected to reduce over the first dimension.
//
// Launch layout: 1-D grid; blockIdx.x is decomposed below into an output-tile
// index and a split-k slice index. threadIdx.y is used as the warp id
// (__launch_bounds__(64) => 2 warps of 32 threads).
//
// When compiled for an architecture below sm_80 the body is just assert(false).
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Zero addend for the fma below, so that fma(x, scale, 0) == x * scale.
  static constexpr uint32_t ZERO = 0x0;
  // Only the first 16 accumulators are initialized and used by this kernel.
  float C_warp[32];
  // Tiles staged in shared memory; each row carries 8 extra halves of padding
  // (presumably to avoid shared-memory bank conflicts -- TODO confirm).
  __shared__ half A_shared[16 * (32 + 8)];
  __shared__ half B_shared[32 * (64 + 8)];

  // Number of 64-wide column tiles.
  int j_factors1 = ((OC + 64 - 1) / 64);
  // Output-tile index (row tile * j_factors1 + column tile).
  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  // Split-k slice index.
  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);

  half A_shared_warp[8];
  half B_shared_warp[16];
  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
    for (int i = 0; i < 8; ++i) {
      C_warp[(j_0_4_init * 8) + i] = 0.0f;
    }
  }

  static constexpr int row_stride_warp = 32 * 8 / 32;
  // Row stride in shared memory: (NWARPS * 32 * 8 / cta_N).
  static constexpr int row_stride = 2 * 32 * 8 / 64;
  // True when the row of A this thread loads is inside the matrix (threadIdx.y is warp_id).
  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;

  // Per-thread global source of the 8 halves of A loaded on every k-step.
  half* A_ptr = A
                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
                + (((int)threadIdx.x) % (32 / 8)) * 8;

  // Per-thread global source of the packed weight words.
  int* B_ptr = B
               + ((int)threadIdx.y) * (OC / 8) * 4
               + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
               + (((int)blockIdx_y) % j_factors1) * (64 / 8)
               + (((int)threadIdx.x) % (64 / 8)) * 1;

  half* A_shared_ptr = A_shared
                       + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
                       + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
                       + (((int)threadIdx.x) % (32 / 8)) * 8;

  half* B_shared_ptr = B_shared
                       + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
                       + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
                       + (((int)threadIdx.x) % (64 / 8)) * 8;

  int* zeros_ptr = zeros
                   + (((int)blockIdx_y) % j_factors1) * (64 / 8)
                   + ((int)threadIdx.x) % (64 / 8);

  half* scaling_factors_ptr = scaling_factors
                              + (((int)blockIdx_y) % j_factors1) * (64)
                              + (((int)threadIdx.x) % (64 / 8)) * 8;

  half* C_ptr = C
                + blockIdx_z * M * OC  // blockIdx_z -> split_k dim
                + (((int)blockIdx_y) % j_factors1) * 64
                + ((int)threadIdx.y) * 32
                + (((int)threadIdx.x) % 4) * 2;

  // Number of 32-wide k-steps handled by this split-k slice; the last step is
  // dropped when it would start at or beyond IC.
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;

  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    // Make sure the previous iteration has finished reading the shared tiles.
    __syncthreads();
    if (ld_A_flag) {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    } else {
      // Rows past M contribute zeros.
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    // Zero points and scales of the quantization group this k-step belongs to.
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
      // B tile in shared memory: 32 x 72 (64 + 8) fp16.
      // Each thread: read 32 bits -> convert to 8 x fp16 (a uint4)
      // -> subtract zero and multiply by scale -> write the uint4 back.
      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);

      // deq = (q - zero) * scale
      // TODO (Haotian): can save 4 assembly instructions if formulated as
      // deq = q * scale - zero * scale.
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

      // write back
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
    }
    // Both shared tiles must be complete before any warp reads them.
    __syncthreads();

    // The 32-wide k-step is consumed as two 16-wide mma steps.
    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
      // Load this warp's A fragment from shared memory.
      {
        unsigned int addr;
        __asm__ __volatile__(
          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
          : "=r"(addr)
          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
        );
        __asm__ __volatile__(
          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
          "{%0, %1, %2, %3}, [%4];\n"
          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
          : "r"(addr)
        );
      }

      // Load this warp's B fragments (transposed on load).
      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
        {
          unsigned int addr;
          __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
            : "=r"(addr)
            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
          );
          __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
            : "r"(addr)
          );
        }
      }

      // Two m16n8k16 tensor-core mma ops per 16-column group, accumulating in fp32.
      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
        }

        {
          __asm__ __volatile__(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }
      }
    }
  }

  // Write the fp32 accumulators back to C as fp16, skipping rows past M.
  // The part of the row index that does not depend on the loop variables is
  // computed once.
  const int row_base = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4;
  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset = row_base + (local_id % 4) / 2 * 8;
      if (row_offset < M) {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
      }
    }
  }
#endif
}
|
||||||
|
|
||||||
|
} // namespace awq
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// in_feats: M, IC [float16]
|
||||||
|
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||||
|
// scaling_factors: IC // G, OC [float16]
|
||||||
|
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||||
|
// assume that batch_size < 16 for now
|
||||||
|
|
||||||
|
// Host entry point: multiplies fp16 activations by 4-bit AWQ-quantized weights.
//
// Parameters:
//   _in_feats         [M, IC] fp16 activations.
//   _kernel           [IC, OC / 8] int32 packed 4-bit weights.
//   _scaling_factors  [IC / G, OC] fp16 per-group scales.
//   _zeros            [IC / G, OC / 8] int32 packed 4-bit zero points.
//   split_k_iters     number of split-k slices (must be positive); each slice
//                     writes its own partial result, which is summed here.
//
// Returns an [M, OC] tensor with the dtype and device of `_in_feats`.
//
// Throws std::invalid_argument when split_k_iters is not positive, when
// `_scaling_factors` has no rows, when OC is not a multiple of 64, when the
// group size is not a multiple of 32, or when OC is not a multiple of the
// group size.
torch::Tensor awq_gemm(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int split_k_iters)
{
  // A non-positive split count would allocate an empty partial-result tensor
  // and launch zero thread blocks; reject it up front instead.
  if (split_k_iters <= 0)
    throw std::invalid_argument("split_k_iters must be positive");
  // The group size below divides by the number of scaling-factor rows.
  if (_scaling_factors.size(0) == 0)
    throw std::invalid_argument("scaling_factors must contain at least one group");

  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  // One partial [M, OC] result per split-k slice; reduced by sum(0) at the end.
  auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
  at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
  int group_size = num_in_channels / _scaling_factors.size(0);

  if (num_out_channels % 64 != 0)
    throw std::invalid_argument("OC is not multiple of cta_N = 64");
  // Cannot trigger once the check above has passed; kept as a statement of the
  // packing requirement.
  if (num_out_channels % 8 != 0)
    throw std::invalid_argument("OC is not multiple of pack_num = 8");
  if (group_size % 32 != 0)
    throw std::invalid_argument("Group size should be a multiple of 32");
  if (num_out_channels % group_size != 0)
    throw std::invalid_argument("OC is not multiple of Group size");

  // Launch on PyTorch's current stream so the kernel is ordered with the
  // surrounding tensor operations even when a non-default stream is active.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // threadIdx.x: 32
  // threadIdx.y: i_factors[2] * j_factors[2]
  dim3 threads_per_block(32, 2);

  if (num_out_channels % 128 == 0)
  {
    int j_factors1 = num_out_channels / 128 / 1;
    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
        group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  else
  {
    // OC is a multiple of 64 here (validated above) but not of 128.
    int j_factors1 = num_out_channels / 64 / 1;
    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
        group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  return _out_feats.sum(0);
}
|
||||||
@@ -3,31 +3,15 @@
|
|||||||
Installation
|
Installation
|
||||||
============
|
============
|
||||||
|
|
||||||
vLLM is a Python library that also contains some C++ and CUDA code.
|
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
|
||||||
This additional code requires compilation on the user's machine.
|
|
||||||
|
|
||||||
Requirements
|
Requirements
|
||||||
------------
|
------------
|
||||||
|
|
||||||
* OS: Linux
|
* OS: Linux
|
||||||
* Python: 3.8 or higher
|
* Python: 3.8 -- 3.11
|
||||||
* CUDA: 11.0 -- 11.8
|
|
||||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
||||||
|
|
||||||
.. note::
|
|
||||||
As of now, vLLM does not support CUDA 12.
|
|
||||||
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
|
|
||||||
|
|
||||||
.. tip::
|
|
||||||
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
|
||||||
|
|
||||||
.. code-block:: console
|
|
||||||
|
|
||||||
$ # Pull the Docker image with CUDA 11.8.
|
|
||||||
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
|
|
||||||
|
|
||||||
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
|
|
||||||
|
|
||||||
Install with pip
|
Install with pip
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
@@ -40,7 +24,7 @@ You can install vLLM using pip:
|
|||||||
$ conda activate myenv
|
$ conda activate myenv
|
||||||
|
|
||||||
$ # Install vLLM.
|
$ # Install vLLM.
|
||||||
$ pip install vllm # This may take 5-10 minutes.
|
$ pip install vllm
|
||||||
|
|
||||||
|
|
||||||
.. _build_from_source:
|
.. _build_from_source:
|
||||||
@@ -55,3 +39,12 @@ You can also build and install vLLM from source:
|
|||||||
$ git clone https://github.com/vllm-project/vllm.git
|
$ git clone https://github.com/vllm-project/vllm.git
|
||||||
$ cd vllm
|
$ cd vllm
|
||||||
$ pip install -e . # This may take 5-10 minutes.
|
$ pip install -e . # This may take 5-10 minutes.
|
||||||
|
|
||||||
|
.. tip::
|
||||||
|
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ # Pull the Docker image with CUDA 11.8.
|
||||||
|
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||||
|
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
|
||||||
|
|||||||
@@ -128,4 +128,4 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
|||||||
prompt="San Francisco is a")
|
prompt="San Francisco is a")
|
||||||
print("Completion result:", completion)
|
print("Completion result:", completion)
|
||||||
|
|
||||||
For a more detailed client example, refer to `examples/openai_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_client.py>`_.
|
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ vLLM is flexible and easy to use with:
|
|||||||
For more information, check out the following:
|
For more information, check out the following:
|
||||||
|
|
||||||
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
|
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
|
||||||
|
* `vLLM paper <https://arxiv.org/abs/2309.06180>`_ (SOSP 2023)
|
||||||
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
|
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
|
||||||
|
|
||||||
|
|
||||||
@@ -63,6 +64,7 @@ Documentation
|
|||||||
|
|
||||||
serving/distributed_serving
|
serving/distributed_serving
|
||||||
serving/run_on_sky
|
serving/run_on_sky
|
||||||
|
serving/deploying_with_triton
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|||||||
@@ -44,6 +44,9 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
* - :code:`LlamaForCausalLM`
|
* - :code:`LlamaForCausalLM`
|
||||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
||||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
||||||
|
* - :code:`MistralForCausalLM`
|
||||||
|
- Mistral, Mistral-Instruct
|
||||||
|
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
||||||
* - :code:`MPTForCausalLM`
|
* - :code:`MPTForCausalLM`
|
||||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||||
|
|||||||
6
docs/source/serving/deploying_with_triton.rst
Normal file
6
docs/source/serving/deploying_with_triton.rst
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.. _deploying_with_triton:
|
||||||
|
|
||||||
|
Deploying with NVIDIA Triton
|
||||||
|
============================
|
||||||
|
|
||||||
|
The `Triton Inference Server <https://github.com/triton-inference-server>`_ hosts a tutorial demonstrating how to quickly deploy a simple `facebook/opt-125m <https://huggingface.co/facebook/opt-125m>`_ model using vLLM. Please see `Deploying a vLLM model in Triton <https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton>`_ for more details.
|
||||||
@@ -11,3 +11,4 @@ types-setuptools
|
|||||||
# testing
|
# testing
|
||||||
pytest
|
pytest
|
||||||
pytest-forked
|
pytest-forked
|
||||||
|
pytest-asyncio
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
ninja # For faster builds.
|
ninja # For faster builds.
|
||||||
psutil
|
psutil
|
||||||
ray >= 2.5.1
|
ray >= 2.5.1
|
||||||
|
pandas # Required for Ray data.
|
||||||
|
pyarrow # Required for Ray data.
|
||||||
sentencepiece # Required for LLaMA tokenizer.
|
sentencepiece # Required for LLaMA tokenizer.
|
||||||
numpy
|
numpy
|
||||||
torch >= 2.0.0
|
torch >= 2.0.0
|
||||||
transformers >= 4.33.1 # Required for Code Llama.
|
transformers >= 4.33.1 # Required for Code Llama.
|
||||||
xformers >= 0.0.21
|
xformers >= 0.0.22
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn[standard]
|
||||||
pydantic < 2 # Required for OpenAI server.
|
pydantic < 2 # Required for OpenAI server.
|
||||||
|
|||||||
171
setup.py
171
setup.py
@@ -3,6 +3,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import List, Set
|
from typing import List, Set
|
||||||
|
import warnings
|
||||||
|
|
||||||
from packaging.version import parse, Version
|
from packaging.version import parse, Version
|
||||||
import setuptools
|
import setuptools
|
||||||
@@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
|||||||
|
|
||||||
ROOT_DIR = os.path.dirname(__file__)
|
ROOT_DIR = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
# Supported NVIDIA GPU architectures.
|
||||||
|
SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
|
||||||
|
|
||||||
# Compiler flags.
|
# Compiler flags.
|
||||||
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
||||||
# TODO(woosuk): Should we use -O3?
|
# TODO(woosuk): Should we use -O3?
|
||||||
@@ -22,7 +26,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
|||||||
|
|
||||||
if CUDA_HOME is None:
|
if CUDA_HOME is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||||
|
|
||||||
|
|
||||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||||
@@ -38,47 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
|||||||
return nvcc_cuda_version
|
return nvcc_cuda_version
|
||||||
|
|
||||||
|
|
||||||
# Collect the compute capabilities of all available GPUs.
|
def get_torch_arch_list() -> Set[str]:
|
||||||
device_count = torch.cuda.device_count()
|
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
||||||
compute_capabilities: Set[int] = set()
|
# e.g. "8.0" or "7.5;8.0;8.6+PTX". Here, the "8.6+PTX" option asks the
|
||||||
for i in range(device_count):
|
# compiler to additionally include PTX code that can be runtime-compiled
|
||||||
major, minor = torch.cuda.get_device_capability(i)
|
# and executed on the 8.6 or newer architectures. While the PTX code will
|
||||||
if major < 7:
|
# not give the best performance on the newer architectures, it provides
|
||||||
raise RuntimeError(
|
# forward compatibility.
|
||||||
"GPUs with compute capability less than 7.0 are not supported.")
|
valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
|
||||||
compute_capabilities.add(major * 10 + minor)
|
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
|
||||||
|
if arch_list is None:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# Entries in the list are separated by ";" or a space.
|
||||||
|
arch_list = arch_list.replace(" ", ";").split(";")
|
||||||
|
for arch in arch_list:
|
||||||
|
if arch not in valid_arch_strs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported CUDA arch ({arch}). "
|
||||||
|
f"Valid CUDA arch strings are: {valid_arch_strs}.")
|
||||||
|
return set(arch_list)
|
||||||
|
|
||||||
|
|
||||||
|
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||||
|
compute_capabilities = get_torch_arch_list()
|
||||||
|
if not compute_capabilities:
|
||||||
|
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||||
|
# GPUs on the current machine.
|
||||||
|
device_count = torch.cuda.device_count()
|
||||||
|
for i in range(device_count):
|
||||||
|
major, minor = torch.cuda.get_device_capability(i)
|
||||||
|
if major < 7:
|
||||||
|
raise RuntimeError(
|
||||||
|
"GPUs with compute capability below 7.0 are not supported.")
|
||||||
|
compute_capabilities.add(f"{major}.{minor}")
|
||||||
|
|
||||||
|
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||||
|
if not compute_capabilities:
|
||||||
|
# If no GPU is specified nor available, add all supported architectures
|
||||||
|
# based on the NVCC CUDA version.
|
||||||
|
compute_capabilities = set(SUPPORTED_ARCHS)
|
||||||
|
if nvcc_cuda_version < Version("11.1"):
|
||||||
|
compute_capabilities.remove("8.6")
|
||||||
|
if nvcc_cuda_version < Version("11.8"):
|
||||||
|
compute_capabilities.remove("8.9")
|
||||||
|
compute_capabilities.remove("9.0")
|
||||||
|
|
||||||
# Validate the NVCC CUDA version.
|
# Validate the NVCC CUDA version.
|
||||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
|
||||||
if nvcc_cuda_version < Version("11.0"):
|
if nvcc_cuda_version < Version("11.0"):
|
||||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
if nvcc_cuda_version < Version("11.1"):
|
||||||
raise RuntimeError(
|
if any(cc.startswith("8.6") for cc in compute_capabilities):
|
||||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
raise RuntimeError(
|
||||||
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
"CUDA 11.1 or higher is required for compute capability 8.6.")
|
||||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
if nvcc_cuda_version < Version("11.8"):
|
||||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||||
# instead of 8.9.
|
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||||
compute_capabilities.remove(89)
|
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||||
compute_capabilities.add(80)
|
# instead of 8.9.
|
||||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
warnings.warn(
|
||||||
raise RuntimeError(
|
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
"Targeting compute capability 8.0 instead.")
|
||||||
|
compute_capabilities = set(cc for cc in compute_capabilities
|
||||||
# If no GPU is available, add all supported compute capabilities.
|
if not cc.startswith("8.9"))
|
||||||
if not compute_capabilities:
|
compute_capabilities.add("8.0+PTX")
|
||||||
compute_capabilities = {70, 75, 80}
|
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||||
if nvcc_cuda_version >= Version("11.1"):
|
raise RuntimeError(
|
||||||
compute_capabilities.add(86)
|
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||||
if nvcc_cuda_version >= Version("11.8"):
|
|
||||||
compute_capabilities.add(89)
|
|
||||||
compute_capabilities.add(90)
|
|
||||||
|
|
||||||
# Add target compute capabilities to NVCC flags.
|
# Add target compute capabilities to NVCC flags.
|
||||||
for capability in compute_capabilities:
|
for capability in compute_capabilities:
|
||||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
num = capability[0] + capability[2]
|
||||||
|
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||||
|
if capability.endswith("+PTX"):
|
||||||
|
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
||||||
|
|
||||||
# Use NVCC threads to parallelize the build.
|
# Use NVCC threads to parallelize the build.
|
||||||
if nvcc_cuda_version >= Version("11.2"):
|
if nvcc_cuda_version >= Version("11.2"):
|
||||||
@@ -91,7 +130,10 @@ ext_modules = []
|
|||||||
cache_extension = CUDAExtension(
|
cache_extension = CUDAExtension(
|
||||||
name="vllm.cache_ops",
|
name="vllm.cache_ops",
|
||||||
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
|
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(cache_extension)
|
ext_modules.append(cache_extension)
|
||||||
|
|
||||||
@@ -99,7 +141,10 @@ ext_modules.append(cache_extension)
|
|||||||
attention_extension = CUDAExtension(
|
attention_extension = CUDAExtension(
|
||||||
name="vllm.attention_ops",
|
name="vllm.attention_ops",
|
||||||
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
|
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(attention_extension)
|
ext_modules.append(attention_extension)
|
||||||
|
|
||||||
@@ -107,7 +152,10 @@ ext_modules.append(attention_extension)
|
|||||||
positional_encoding_extension = CUDAExtension(
|
positional_encoding_extension = CUDAExtension(
|
||||||
name="vllm.pos_encoding_ops",
|
name="vllm.pos_encoding_ops",
|
||||||
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
|
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(positional_encoding_extension)
|
ext_modules.append(positional_encoding_extension)
|
||||||
|
|
||||||
@@ -115,7 +163,10 @@ ext_modules.append(positional_encoding_extension)
|
|||||||
layernorm_extension = CUDAExtension(
|
layernorm_extension = CUDAExtension(
|
||||||
name="vllm.layernorm_ops",
|
name="vllm.layernorm_ops",
|
||||||
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
|
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(layernorm_extension)
|
ext_modules.append(layernorm_extension)
|
||||||
|
|
||||||
@@ -123,10 +174,38 @@ ext_modules.append(layernorm_extension)
|
|||||||
activation_extension = CUDAExtension(
|
activation_extension = CUDAExtension(
|
||||||
name="vllm.activation_ops",
|
name="vllm.activation_ops",
|
||||||
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
|
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(activation_extension)
|
ext_modules.append(activation_extension)
|
||||||
|
|
||||||
|
# Quantization kernels.
|
||||||
|
quantization_extension = CUDAExtension(
|
||||||
|
name="vllm.quantization_ops",
|
||||||
|
sources=[
|
||||||
|
"csrc/quantization.cpp",
|
||||||
|
"csrc/quantization/awq/gemm_kernels.cu",
|
||||||
|
],
|
||||||
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ext_modules.append(quantization_extension)
|
||||||
|
|
||||||
|
# Misc. CUDA utils.
|
||||||
|
cuda_utils_extension = CUDAExtension(
|
||||||
|
name="vllm.cuda_utils",
|
||||||
|
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
|
||||||
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ext_modules.append(cuda_utils_extension)
|
||||||
|
|
||||||
|
|
||||||
def get_path(*filepath) -> str:
|
def get_path(*filepath) -> str:
|
||||||
return os.path.join(ROOT_DIR, *filepath)
|
return os.path.join(ROOT_DIR, *filepath)
|
||||||
@@ -138,8 +217,8 @@ def find_version(filepath: str):
|
|||||||
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
|
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
|
||||||
"""
|
"""
|
||||||
with open(filepath) as fp:
|
with open(filepath) as fp:
|
||||||
version_match = re.search(
|
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
|
||||||
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
|
fp.read(), re.M)
|
||||||
if version_match:
|
if version_match:
|
||||||
return version_match.group(1)
|
return version_match.group(1)
|
||||||
raise RuntimeError("Unable to find version string.")
|
raise RuntimeError("Unable to find version string.")
|
||||||
@@ -162,7 +241,8 @@ setuptools.setup(
|
|||||||
version=find_version(get_path("vllm", "__init__.py")),
|
version=find_version(get_path("vllm", "__init__.py")),
|
||||||
author="vLLM Team",
|
author="vLLM Team",
|
||||||
license="Apache 2.0",
|
license="Apache 2.0",
|
||||||
description="A high-throughput and memory-efficient inference and serving engine for LLMs",
|
description=("A high-throughput and memory-efficient inference and "
|
||||||
|
"serving engine for LLMs"),
|
||||||
long_description=read_readme(),
|
long_description=read_readme(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
url="https://github.com/vllm-project/vllm",
|
url="https://github.com/vllm-project/vllm",
|
||||||
@@ -174,11 +254,12 @@ setuptools.setup(
|
|||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
],
|
],
|
||||||
packages=setuptools.find_packages(
|
packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
|
||||||
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
|
"examples", "tests")),
|
||||||
python_requires=">=3.8",
|
python_requires=">=3.8",
|
||||||
install_requires=get_requirements(),
|
install_requires=get_requirements(),
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
|
|||||||
80
tests/async_engine/test_async_llm_engine.py
Normal file
80
tests/async_engine/test_async_llm_engine.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RequestOutput:
    """Bare-bones stand-in for an engine output, used by the mock engine."""
    # Request ids in this test are strings ("1", "2", "3"), so the field is
    # annotated as str rather than int.
    request_id: str
    # Defaults to False; nothing in this file sets it to True.
    finished: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MockEngine:
    """Engine stand-in that only counts how often each entry point is hit."""

    def __init__(self):
        # No request is "generating" until generate() is called.
        self.request_id = None
        self.abort_request_calls = 0
        self.add_request_calls = 0
        self.step_calls = 0

    async def step_async(self):
        """Count one step; return an output only while a request id is set."""
        self.step_calls += 1
        if not self.request_id:
            return []
        return [RequestOutput(request_id=self.request_id)]

    def generate(self, request_id):
        """Mark `request_id` as the request that steps produce output for."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so later steps return no outputs."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count the call; the keyword arguments are ignored."""
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count the call; `request_id` is ignored."""
        self.abort_request_calls += 1
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncLLMEngine(AsyncLLMEngine):
    """AsyncLLMEngine variant whose `_init_engine` hands back a MockEngine."""

    def _init_engine(self, *args, **kwargs):
        # The arguments are accepted but unused; a fresh mock is returned.
        mock_engine = MockEngine()
        return mock_engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_new_requests_event():
    """Track the mock's step/add_request counters as requests come and go."""
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    mock = engine.engine

    # No request has been submitted yet, so no step is expected.
    await asyncio.sleep(0.01)
    assert mock.step_calls == 0

    # First request: one add_request call and one step.
    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert mock.add_request_calls == 1
    assert mock.step_calls == 1

    # Second request, with the mock emitting outputs for it: the step count
    # advances on each yield to the event loop until generation is stopped.
    await engine.add_request("2", "", None)
    mock.generate("2")
    await asyncio.sleep(0)
    assert mock.add_request_calls == 2
    assert mock.step_calls == 2
    await asyncio.sleep(0)
    assert mock.step_calls == 3
    mock.stop_generating()
    await asyncio.sleep(0)
    assert mock.step_calls == 4
    await asyncio.sleep(0)
    assert mock.step_calls == 4

    # Third request: exactly one more step, and none after that.
    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert mock.add_request_calls == 3
    assert mock.step_calls == 5
    await asyncio.sleep(0.01)
    assert mock.add_request_calls == 3
    assert mock.step_calls == 5
|
||||||
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
|
|||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEvent:
    """Plain object exposing set()/clear() over a boolean `_flag`."""

    def __init__(self):
        # The flag starts lowered.
        self._flag = False

    def set(self):
        """Raise the flag."""
        self._flag = True

    def clear(self):
        """Lower the flag."""
        self._flag = False
|
||||||
|
|
||||||
|
|
||||||
def test_request_tracker():
|
def test_request_tracker():
|
||||||
tracker = RequestTracker()
|
tracker = RequestTracker()
|
||||||
|
tracker.new_requests_event = DummyEvent()
|
||||||
stream_1 = tracker.add_request("1")
|
stream_1 = tracker.add_request("1")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
new, finished = tracker.get_new_and_finished_requests()
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
assert len(new) == 1
|
assert len(new) == 1
|
||||||
assert new[0]["request_id"] == "1"
|
assert new[0]["request_id"] == "1"
|
||||||
assert not finished
|
assert not finished
|
||||||
@@ -15,7 +30,9 @@ def test_request_tracker():
|
|||||||
|
|
||||||
stream_2 = tracker.add_request("2")
|
stream_2 = tracker.add_request("2")
|
||||||
stream_3 = tracker.add_request("3")
|
stream_3 = tracker.add_request("3")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
new, finished = tracker.get_new_and_finished_requests()
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
assert len(new) == 2
|
assert len(new) == 2
|
||||||
assert new[0]["request_id"] == "2"
|
assert new[0]["request_id"] == "2"
|
||||||
assert new[1]["request_id"] == "3"
|
assert new[1]["request_id"] == "3"
|
||||||
@@ -26,6 +43,7 @@ def test_request_tracker():
|
|||||||
# request_ids must be unique
|
# request_ids must be unique
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
tracker.add_request("1")
|
tracker.add_request("1")
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
|
|
||||||
tracker.abort_request("1")
|
tracker.abort_request("1")
|
||||||
new, finished = tracker.get_new_and_finished_requests()
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
@@ -36,6 +54,7 @@ def test_request_tracker():
|
|||||||
|
|
||||||
stream_4 = tracker.add_request("4")
|
stream_4 = tracker.add_request("4")
|
||||||
tracker.abort_request("4")
|
tracker.abort_request("4")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
new, finished = tracker.get_new_and_finished_requests()
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
assert len(finished) == 1
|
assert len(finished) == 1
|
||||||
assert "4" in finished
|
assert "4" in finished
|
||||||
@@ -43,9 +62,11 @@ def test_request_tracker():
|
|||||||
assert stream_4.finished
|
assert stream_4.finished
|
||||||
|
|
||||||
stream_5 = tracker.add_request("5")
|
stream_5 = tracker.add_request("5")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
tracker.process_request_output(
|
tracker.process_request_output(
|
||||||
RequestOutput("2", "output", [], [], finished=True))
|
RequestOutput("2", "output", [], [], finished=True))
|
||||||
new, finished = tracker.get_new_and_finished_requests()
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
assert len(finished) == 1
|
assert len(finished) == 1
|
||||||
assert "2" in finished
|
assert "2" in finished
|
||||||
assert len(new) == 1
|
assert len(new) == 1
|
||||||
|
|||||||
62
tests/engine/test_detokenize.py
Normal file
62
tests/engine/test_detokenize.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.transformers_utils.tokenizer import detokenize_incrementally
|
||||||
|
|
||||||
|
# Reference strings: the streaming test asserts that incrementally decoded
# text equals each of these exactly.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
    "我很感谢你的热情"
]
# Tokenizer identifiers passed to AutoTokenizer.from_pretrained by the test.
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "codellama/CodeLlama-7b-hf",
]
|
||||||
|
|
||||||
|
|
||||||
|
def _run_incremental_decode(tokenizer, all_input_ids,
|
||||||
|
skip_special_tokens: bool):
|
||||||
|
decoded_text = ""
|
||||||
|
offset = 0
|
||||||
|
token_offset = 0
|
||||||
|
prev_tokens = None
|
||||||
|
for i in range(len(all_input_ids)):
|
||||||
|
new_tokens, text, offset, token_offset = detokenize_incrementally(
|
||||||
|
tokenizer,
|
||||||
|
all_input_ids[:i + 1],
|
||||||
|
prev_tokens,
|
||||||
|
offset,
|
||||||
|
token_offset,
|
||||||
|
skip_special_tokens=skip_special_tokens)
|
||||||
|
decoded_text += text
|
||||||
|
if prev_tokens is None:
|
||||||
|
prev_tokens = new_tokens
|
||||||
|
else:
|
||||||
|
prev_tokens += new_tokens
|
||||||
|
return decoded_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
    """Incremental decoding must reproduce the reference string exactly."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
    if skip_special_tokens:
        # Wrap the ids in BOS (when the tokenizer defines one) and EOS; with
        # skip_special_tokens set, the decoded text must still equal `truth`.
        bos_id = tokenizer.bos_token_id
        prefix = [] if bos_id is None else [bos_id]
        all_input_ids = prefix + all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)

    assert decoded_text == truth
|
||||||
@@ -7,8 +7,12 @@ from xformers import ops as xops
|
|||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||||
|
|
||||||
from vllm import attention_ops
|
from vllm import attention_ops
|
||||||
|
from vllm.utils import get_max_shared_memory_bytes
|
||||||
|
|
||||||
MAX_SEQ_LEN = 8192
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
|
# This will change depending on the compute capability.
|
||||||
|
# - 512 as a buffer
|
||||||
|
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||||
NUM_BLOCKS = 128 # Arbitrary values for testing
|
NUM_BLOCKS = 128 # Arbitrary values for testing
|
||||||
|
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
|
|||||||
device="cuda")
|
device="cuda")
|
||||||
|
|
||||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||||
|
context_lens[-1] = MAX_SEQ_LEN
|
||||||
max_context_len = max(context_lens)
|
max_context_len = max(context_lens)
|
||||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||||
|
|
||||||
@@ -242,7 +247,11 @@ def test_multi_query_kv_attention(
|
|||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
|
||||||
|
# As the xformers library is already tested with its own tests, we can use
|
||||||
|
# a smaller MAX_SEQ_LEN here.
|
||||||
|
max_len = min(MAX_SEQ_LEN, 4096)
|
||||||
|
seq_lens = random.sample(range(1, max_len), num_seqs)
|
||||||
num_tokens = sum(seq_lens)
|
num_tokens = sum(seq_lens)
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
|
|||||||
@@ -133,9 +133,10 @@ def test_rotary_embedding(
|
|||||||
device="cuda")
|
device="cuda")
|
||||||
|
|
||||||
# Create the rotary embedding.
|
# Create the rotary embedding.
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
inv_freq = 1.0 / (base**(
|
||||||
|
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||||
t = torch.arange(max_position).float()
|
t = torch.arange(max_position).float()
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
|||||||
184
tests/samplers/test_sampler.py
Normal file
184
tests/samplers/test_sampler.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
import pytest
|
||||||
|
import random
|
||||||
|
from typing import Tuple
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.utils import set_random_seed
|
||||||
|
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||||
|
from vllm.worker.worker import Worker
|
||||||
|
|
||||||
|
|
||||||
|
class MockLogitsSampler(Sampler):
    """Sampler whose logits are replaced by a caller-supplied tensor."""

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        """Run the parent forward with two module-level helpers patched.

        `_prune_hidden_states` is replaced by a function returning its first
        argument unchanged, and `_get_logits` by one returning
        `self.fake_logits` regardless of its arguments.
        """
        keep_hidden_states = patch(
            "vllm.model_executor.layers.sampler._prune_hidden_states",
            lambda x, y: x)
        inject_logits = patch(
            "vllm.model_executor.layers.sampler._get_logits",
            lambda *args, **kwargs: self.fake_logits)
        with keep_hidden_states, inject_logits:
            return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    """Build the fixtures shared by the sampler tests.

    Args:
        batch_size: Number of rows in the input and logits tensors.

    Returns:
        A tuple `(input_tensor, fake_logits, sampler, worker)`:
        a random float16 CUDA tensor of shape `(batch_size, 1024)`, a logits
        tensor of shape `(batch_size, vocab_size)` filled with 1e-2 on the
        same device and dtype, a MockLogitsSampler serving those logits, and
        a Worker built with three `None` arguments and `block_size` set to 16.
    """
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    # Reuse vocab_size instead of repeating the literal, so the sampler's
    # vocabulary size cannot drift from the width of the logits tensor.
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker
|
||||||
|
|
||||||
|
|
||||||
|
# Seeds 0 through 127; each sampler test is parametrized over all of them.
RANDOM_SEEDS = list(range(128))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    """Every request uses temperature 0; outputs must match the argmax."""
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0, ),
            block_tables={0: [1]},
        ) for i in range(batch_size)
    ]

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)

    # Row-wise argmax of the fake logits is the expected token per request.
    expected = torch.argmax(fake_logits, dim=-1)
    for row, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == expected[row].item()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    """Every request samples at temperature 1.0 from one dominant logit."""
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    # Give row r a single large logit at column r, so token r is the
    # expected sample for request r.
    for row in range(batch_size):
        fake_logits[row, row] = 1e2

    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(
                temperature=1.0,
                n=random.randint(1, 10),
            ),
            block_tables={0: [1]},
        ) for i in range(batch_size)
    ]

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)

    for row, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == row
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    """Smoke test: a batch made only of beam-search requests must not raise."""
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(
                temperature=0,
                best_of=2,
                use_beam_search=True,
            ),
            block_tables={0: [1]},
        ) for i in range(batch_size)
    ]

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    # The return value is deliberately not inspected: the original author
    # noted there is no obvious way to state the expected beam-search
    # outputs, so this only checks that the sampler runs without raising
    # when every request uses beam search.
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    """Mix greedy, random and beam-search requests in a single batch."""
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        sampling_type = random.randint(0, 2)
        if sampling_type == 1:
            # Random sampling: the random draws happen in the same order as
            # the keyword arguments are listed below.
            n = random.randint(1, 10)
            temperature = random.random() + 0.1
            top_p = min(random.random() + 0.1, 1)
            top_k = random.randint(0, 10) or -1
            presence_penalty = random.randint(0, 1)
            sampling_params = SamplingParams(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                n=n,
                presence_penalty=presence_penalty,
            )
        else:
            n = 1
            if sampling_type == 0:
                sampling_params = SamplingParams(temperature=0)
            else:
                sampling_params = SamplingParams(temperature=0,
                                                 use_beam_search=True,
                                                 best_of=2)

        # Boost n consecutive logits starting at column i and record those
        # token ids as acceptable outputs.
        fake_logits[i, i:i + n] = 1e2
        expected_tokens.extend(range(i, i + n))

        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)

    for i, sequence_output in enumerate(sampler_output):
        # Beam-search requests are not checked.
        if not seq_group_metadata_list[i].sampling_params.use_beam_search:
            for nth_output in sequence_output:
                assert nth_output.output_token in expected_tokens
|
||||||
@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
|||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
__version__ = "0.1.6"
|
__version__ = "0.2.0"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LLM",
|
"LLM",
|
||||||
|
|||||||
144
vllm/config.py
144
vllm/config.py
@@ -38,6 +38,13 @@ class ModelConfig:
|
|||||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||||
for BF16 models.
|
for BF16 models.
|
||||||
seed: Random seed for reproducibility.
|
seed: Random seed for reproducibility.
|
||||||
|
revision: The specific model version to use. It can be a branch name,
|
||||||
|
a tag name, or a commit id. If unspecified, will use the default
|
||||||
|
version.
|
||||||
|
max_model_len: Maximum length of a sequence (including prompt and
|
||||||
|
output). If None, will be derived from the model.
|
||||||
|
quantization: Quantization method that was used to quantize the model
|
||||||
|
weights. If None, we assume the model weights are not quantized.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -50,6 +57,9 @@ class ModelConfig:
|
|||||||
load_format: str,
|
load_format: str,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
max_model_len: Optional[int] = None,
|
||||||
|
quantization: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
@@ -58,11 +68,16 @@ class ModelConfig:
|
|||||||
self.download_dir = download_dir
|
self.download_dir = download_dir
|
||||||
self.load_format = load_format
|
self.load_format = load_format
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
self.revision = revision
|
||||||
|
self.quantization = quantization
|
||||||
|
|
||||||
self.hf_config = get_config(model, trust_remote_code)
|
self.hf_config = get_config(model, trust_remote_code, revision)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||||
|
self.max_model_len = _get_and_verify_max_len(self.hf_config,
|
||||||
|
max_model_len)
|
||||||
self._verify_load_format()
|
self._verify_load_format()
|
||||||
self._verify_tokenizer_mode()
|
self._verify_tokenizer_mode()
|
||||||
|
self._verify_quantization()
|
||||||
|
|
||||||
def _verify_load_format(self) -> None:
|
def _verify_load_format(self) -> None:
|
||||||
load_format = self.load_format.lower()
|
load_format = self.load_format.lower()
|
||||||
@@ -82,6 +97,17 @@ class ModelConfig:
|
|||||||
"either 'auto' or 'slow'.")
|
"either 'auto' or 'slow'.")
|
||||||
self.tokenizer_mode = tokenizer_mode
|
self.tokenizer_mode = tokenizer_mode
|
||||||
|
|
||||||
|
def _verify_quantization(self) -> None:
|
||||||
|
supported_quantization = ["awq"]
|
||||||
|
if self.quantization is None:
|
||||||
|
return
|
||||||
|
quantization = self.quantization.lower()
|
||||||
|
if quantization not in supported_quantization:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown quantization: {self.quantization}. Must be one of "
|
||||||
|
f"{supported_quantization}.")
|
||||||
|
self.quantization = quantization
|
||||||
|
|
||||||
def verify_with_parallel_config(
|
def verify_with_parallel_config(
|
||||||
self,
|
self,
|
||||||
parallel_config: "ParallelConfig",
|
parallel_config: "ParallelConfig",
|
||||||
@@ -109,22 +135,28 @@ class ModelConfig:
|
|||||||
# FIXME(woosuk): This may not be true for all models.
|
# FIXME(woosuk): This may not be true for all models.
|
||||||
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||||
|
|
||||||
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
|
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||||
|
"""Returns the number of KV heads per GPU worker."""
|
||||||
# For GPTBigCode & Falcon:
|
# For GPTBigCode & Falcon:
|
||||||
# Note: for falcon, when new_decoder_architecture is True, the
|
# Note: for falcon, when new_decoder_architecture is True, the
|
||||||
# multi_query flag is ignored and we use n_head_kv for the number of
|
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||||
# KV heads.
|
# KV heads.
|
||||||
|
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||||
new_decoder_arch_falcon = (
|
new_decoder_arch_falcon = (
|
||||||
self.hf_config.model_type == "falcon"
|
self.hf_config.model_type in falcon_model_types
|
||||||
and getattr(self.hf_config, "new_decoder_architecture", False))
|
and getattr(self.hf_config, "new_decoder_architecture", False))
|
||||||
if not new_decoder_arch_falcon and getattr(self.hf_config,
|
if not new_decoder_arch_falcon and getattr(self.hf_config,
|
||||||
"multi_query", False):
|
"multi_query", False):
|
||||||
# Multi-query attention, only one KV head.
|
# Multi-query attention, only one KV head.
|
||||||
|
# Currently, tensor parallelism is not supported in this case.
|
||||||
return 1
|
return 1
|
||||||
# For Falcon:
|
# For Falcon:
|
||||||
if getattr(self.hf_config, "n_head_kv", None) is not None:
|
if getattr(self.hf_config, "n_head_kv", None) is not None:
|
||||||
return (self.hf_config.n_head_kv //
|
return (self.hf_config.n_head_kv //
|
||||||
parallel_config.tensor_parallel_size)
|
parallel_config.tensor_parallel_size)
|
||||||
|
if getattr(self.hf_config, "num_kv_heads", None) is not None:
|
||||||
|
return (self.hf_config.num_kv_heads //
|
||||||
|
parallel_config.tensor_parallel_size)
|
||||||
# For LLaMA-2:
|
# For LLaMA-2:
|
||||||
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
|
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
|
||||||
return (self.hf_config.num_key_value_heads //
|
return (self.hf_config.num_key_value_heads //
|
||||||
@@ -132,26 +164,6 @@ class ModelConfig:
|
|||||||
total_num_attention_heads = self.hf_config.num_attention_heads
|
total_num_attention_heads = self.hf_config.num_attention_heads
|
||||||
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
||||||
|
|
||||||
def get_max_model_len(self) -> int:
|
|
||||||
max_model_len = float("inf")
|
|
||||||
possible_keys = [
|
|
||||||
# OPT
|
|
||||||
"max_position_embeddings",
|
|
||||||
# GPT-2
|
|
||||||
"n_positions",
|
|
||||||
# MPT
|
|
||||||
"max_seq_len",
|
|
||||||
# Others
|
|
||||||
"max_sequence_length",
|
|
||||||
"max_seq_length",
|
|
||||||
"seq_len",
|
|
||||||
]
|
|
||||||
for key in possible_keys:
|
|
||||||
max_len_key = getattr(self.hf_config, key, None)
|
|
||||||
if max_len_key is not None:
|
|
||||||
max_model_len = min(max_model_len, max_len_key)
|
|
||||||
return max_model_len
|
|
||||||
|
|
||||||
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||||
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
||||||
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
||||||
@@ -172,10 +184,12 @@ class CacheConfig:
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
gpu_memory_utilization: float,
|
gpu_memory_utilization: float,
|
||||||
swap_space: int,
|
swap_space: int,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.gpu_memory_utilization = gpu_memory_utilization
|
self.gpu_memory_utilization = gpu_memory_utilization
|
||||||
self.swap_space_bytes = swap_space * _GB
|
self.swap_space_bytes = swap_space * _GB
|
||||||
|
self.sliding_window = sliding_window
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
|
|
||||||
# Will be set after profiling.
|
# Will be set after profiling.
|
||||||
@@ -251,11 +265,36 @@ class SchedulerConfig:
|
|||||||
and generated text).
|
and generated text).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
|
def __init__(
|
||||||
max_model_len: int) -> None:
|
self,
|
||||||
self.max_num_batched_tokens = max_num_batched_tokens
|
max_num_batched_tokens: Optional[int],
|
||||||
|
max_num_seqs: int,
|
||||||
|
max_model_len: int,
|
||||||
|
) -> None:
|
||||||
|
if max_num_batched_tokens is not None:
|
||||||
|
self.max_num_batched_tokens = max_num_batched_tokens
|
||||||
|
else:
|
||||||
|
# If max_model_len is too short, use 2048 as the default value for
|
||||||
|
# higher throughput.
|
||||||
|
self.max_num_batched_tokens = max(max_model_len, 2048)
|
||||||
self.max_num_seqs = max_num_seqs
|
self.max_num_seqs = max_num_seqs
|
||||||
self.max_model_len = max_model_len
|
self.max_model_len = max_model_len
|
||||||
|
self._verify_args()
|
||||||
|
|
||||||
|
def _verify_args(self) -> None:
|
||||||
|
if self.max_num_batched_tokens < self.max_model_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||||
|
f"smaller than max_model_len ({self.max_model_len}). "
|
||||||
|
"This effectively limits the maximum sequence length to "
|
||||||
|
"max_num_batched_tokens and makes vLLM reject longer "
|
||||||
|
"sequences. Please increase max_num_batched_tokens or "
|
||||||
|
"decrease max_model_len.")
|
||||||
|
if self.max_num_batched_tokens < self.max_num_seqs:
|
||||||
|
raise ValueError(
|
||||||
|
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
|
||||||
|
"be greater than or equal to max_num_seqs "
|
||||||
|
f"({self.max_num_seqs}).")
|
||||||
|
|
||||||
|
|
||||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
@@ -311,3 +350,56 @@ def _get_and_verify_dtype(
|
|||||||
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
||||||
f"{compute_capability[0]}.{compute_capability[1]}.")
|
f"{compute_capability[0]}.{compute_capability[1]}.")
|
||||||
return torch_dtype
|
return torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _get_and_verify_max_len(
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
max_model_len: Optional[int],
|
||||||
|
) -> int:
|
||||||
|
"""Get and verify the model's maximum length."""
|
||||||
|
derived_max_model_len = float("inf")
|
||||||
|
possible_keys = [
|
||||||
|
# OPT
|
||||||
|
"max_position_embeddings",
|
||||||
|
# GPT-2
|
||||||
|
"n_positions",
|
||||||
|
# MPT
|
||||||
|
"max_seq_len",
|
||||||
|
# Others
|
||||||
|
"max_sequence_length",
|
||||||
|
"max_seq_length",
|
||||||
|
"seq_len",
|
||||||
|
]
|
||||||
|
for key in possible_keys:
|
||||||
|
max_len_key = getattr(hf_config, key, None)
|
||||||
|
if max_len_key is not None:
|
||||||
|
derived_max_model_len = min(derived_max_model_len, max_len_key)
|
||||||
|
if derived_max_model_len == float("inf"):
|
||||||
|
if max_model_len is not None:
|
||||||
|
# If max_model_len is specified, we use it.
|
||||||
|
return max_model_len
|
||||||
|
|
||||||
|
default_max_len = 2048
|
||||||
|
logger.warning(
|
||||||
|
"The model's config.json does not contain any of the following "
|
||||||
|
"keys to determine the original maximum length of the model: "
|
||||||
|
f"{possible_keys}. Assuming the model's maximum length is "
|
||||||
|
f"{default_max_len}.")
|
||||||
|
derived_max_model_len = default_max_len
|
||||||
|
|
||||||
|
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||||
|
if rope_scaling is not None:
|
||||||
|
assert "factor" in rope_scaling
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
derived_max_model_len *= scaling_factor
|
||||||
|
|
||||||
|
if max_model_len is None:
|
||||||
|
max_model_len = derived_max_model_len
|
||||||
|
elif max_model_len > derived_max_model_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"User-specified max_model_len ({max_model_len}) is greater than "
|
||||||
|
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
|
||||||
|
" in model's config.json). This may lead to incorrect model "
|
||||||
|
"outputs or CUDA errors. Make sure the value is correct and "
|
||||||
|
"within the model context size.")
|
||||||
|
return int(max_model_len)
|
||||||
|
|||||||
@@ -63,10 +63,18 @@ class BlockSpaceManager:
|
|||||||
num_gpu_blocks: int,
|
num_gpu_blocks: int,
|
||||||
num_cpu_blocks: int,
|
num_cpu_blocks: int,
|
||||||
watermark: float = 0.01,
|
watermark: float = 0.01,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.num_total_gpu_blocks = num_gpu_blocks
|
self.num_total_gpu_blocks = num_gpu_blocks
|
||||||
self.num_total_cpu_blocks = num_cpu_blocks
|
self.num_total_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
self.block_sliding_window = None
|
||||||
|
if sliding_window is not None:
|
||||||
|
assert sliding_window % block_size == 0, (sliding_window,
|
||||||
|
block_size)
|
||||||
|
self.block_sliding_window = sliding_window // block_size
|
||||||
|
|
||||||
self.watermark = watermark
|
self.watermark = watermark
|
||||||
assert watermark >= 0.0
|
assert watermark >= 0.0
|
||||||
|
|
||||||
@@ -83,6 +91,9 @@ class BlockSpaceManager:
|
|||||||
# the same prompt. This may not be true for preempted sequences.
|
# the same prompt. This may not be true for preempted sequences.
|
||||||
seq = seq_group.get_seqs()[0]
|
seq = seq_group.get_seqs()[0]
|
||||||
num_required_blocks = len(seq.logical_token_blocks)
|
num_required_blocks = len(seq.logical_token_blocks)
|
||||||
|
if self.block_sliding_window is not None:
|
||||||
|
num_required_blocks = min(num_required_blocks,
|
||||||
|
self.block_sliding_window)
|
||||||
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
||||||
# Use watermark to avoid frequent cache eviction.
|
# Use watermark to avoid frequent cache eviction.
|
||||||
return (num_free_gpu_blocks - num_required_blocks >=
|
return (num_free_gpu_blocks - num_required_blocks >=
|
||||||
@@ -95,8 +106,12 @@ class BlockSpaceManager:
|
|||||||
|
|
||||||
# Allocate new physical token blocks that will store the prompt tokens.
|
# Allocate new physical token blocks that will store the prompt tokens.
|
||||||
block_table: BlockTable = []
|
block_table: BlockTable = []
|
||||||
for _ in range(len(seq.logical_token_blocks)):
|
for logical_idx in range(len(seq.logical_token_blocks)):
|
||||||
block = self.gpu_allocator.allocate()
|
if (self.block_sliding_window is not None
|
||||||
|
and logical_idx >= self.block_sliding_window):
|
||||||
|
block = block_table[logical_idx % self.block_sliding_window]
|
||||||
|
else:
|
||||||
|
block = self.gpu_allocator.allocate()
|
||||||
# Set the reference counts of the token blocks.
|
# Set the reference counts of the token blocks.
|
||||||
block.ref_count = seq_group.num_seqs()
|
block.ref_count = seq_group.num_seqs()
|
||||||
block_table.append(block)
|
block_table.append(block)
|
||||||
@@ -118,11 +133,17 @@ class BlockSpaceManager:
|
|||||||
block_table = self.block_tables[seq.seq_id]
|
block_table = self.block_tables[seq.seq_id]
|
||||||
|
|
||||||
if len(block_table) < len(logical_blocks):
|
if len(block_table) < len(logical_blocks):
|
||||||
# The sequence has a new logical block.
|
if (self.block_sliding_window
|
||||||
# Allocate a new physical block.
|
and len(block_table) >= self.block_sliding_window):
|
||||||
block = self.gpu_allocator.allocate()
|
# re-use a block
|
||||||
block_table.append(block)
|
block_table.append(block_table[len(block_table) %
|
||||||
return None
|
self.block_sliding_window])
|
||||||
|
else:
|
||||||
|
# The sequence has a new logical block.
|
||||||
|
# Allocate a new physical block.
|
||||||
|
block = self.gpu_allocator.allocate()
|
||||||
|
block_table.append(block)
|
||||||
|
return None
|
||||||
|
|
||||||
# We want to append the token to the last physical block.
|
# We want to append the token to the last physical block.
|
||||||
last_block = block_table[-1]
|
last_block = block_table[-1]
|
||||||
@@ -154,9 +175,7 @@ class BlockSpaceManager:
|
|||||||
for seq in seq_group.get_seqs():
|
for seq in seq_group.get_seqs():
|
||||||
if seq.is_finished():
|
if seq.is_finished():
|
||||||
continue
|
continue
|
||||||
block_table = self.block_tables[seq.seq_id]
|
blocks.update(self.block_tables[seq.seq_id])
|
||||||
for block in block_table:
|
|
||||||
blocks.add(block)
|
|
||||||
return list(blocks)
|
return list(blocks)
|
||||||
|
|
||||||
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
|
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
|
||||||
@@ -224,7 +243,7 @@ class BlockSpaceManager:
|
|||||||
return block_number_mapping
|
return block_number_mapping
|
||||||
|
|
||||||
def _free_block_table(self, block_table: BlockTable) -> None:
|
def _free_block_table(self, block_table: BlockTable) -> None:
|
||||||
for block in block_table:
|
for block in set(block_table):
|
||||||
if block.device == Device.GPU:
|
if block.device == Device.GPU:
|
||||||
self.gpu_allocator.free(block)
|
self.gpu_allocator.free(block)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class Scheduler:
|
|||||||
block_size=self.cache_config.block_size,
|
block_size=self.cache_config.block_size,
|
||||||
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
||||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||||
)
|
sliding_window=self.cache_config.sliding_window)
|
||||||
|
|
||||||
# TODO(zhuohan): Use deque instead of list for better performance.
|
# TODO(zhuohan): Use deque instead of list for better performance.
|
||||||
# Sequence groups in the WAITING state.
|
# Sequence groups in the WAITING state.
|
||||||
@@ -175,7 +175,7 @@ class Scheduler:
|
|||||||
num_curr_seqs += num_new_seqs
|
num_curr_seqs += num_new_seqs
|
||||||
scheduled.append(seq_group)
|
scheduled.append(seq_group)
|
||||||
|
|
||||||
if scheduled:
|
if scheduled or ignored_seq_groups:
|
||||||
scheduler_outputs = SchedulerOutputs(
|
scheduler_outputs = SchedulerOutputs(
|
||||||
scheduled_seq_groups=scheduled,
|
scheduled_seq_groups=scheduled,
|
||||||
prompt_run=True,
|
prompt_run=True,
|
||||||
|
|||||||
@@ -18,20 +18,22 @@ class EngineArgs:
|
|||||||
load_format: str = 'auto'
|
load_format: str = 'auto'
|
||||||
dtype: str = 'auto'
|
dtype: str = 'auto'
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
|
max_model_len: Optional[int] = None
|
||||||
worker_use_ray: bool = False
|
worker_use_ray: bool = False
|
||||||
pipeline_parallel_size: int = 1
|
pipeline_parallel_size: int = 1
|
||||||
tensor_parallel_size: int = 1
|
tensor_parallel_size: int = 1
|
||||||
block_size: int = 16
|
block_size: int = 16
|
||||||
swap_space: int = 4 # GiB
|
swap_space: int = 4 # GiB
|
||||||
gpu_memory_utilization: float = 0.90
|
gpu_memory_utilization: float = 0.90
|
||||||
max_num_batched_tokens: int = 2560
|
max_num_batched_tokens: Optional[int] = None
|
||||||
max_num_seqs: int = 256
|
max_num_seqs: int = 256
|
||||||
disable_log_stats: bool = False
|
disable_log_stats: bool = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
quantization: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
self.tokenizer = self.model
|
self.tokenizer = self.model
|
||||||
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
@@ -48,6 +50,13 @@ class EngineArgs:
|
|||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.tokenizer,
|
default=EngineArgs.tokenizer,
|
||||||
help='name or path of the huggingface tokenizer to use')
|
help='name or path of the huggingface tokenizer to use')
|
||||||
|
parser.add_argument(
|
||||||
|
'--revision',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='the specific model version to use. It can be a branch '
|
||||||
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
|
'the default version.')
|
||||||
parser.add_argument('--tokenizer-mode',
|
parser.add_argument('--tokenizer-mode',
|
||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.tokenizer_mode,
|
default=EngineArgs.tokenizer_mode,
|
||||||
@@ -79,16 +88,22 @@ class EngineArgs:
|
|||||||
'a numpy cache to speed up the loading. '
|
'a numpy cache to speed up the loading. '
|
||||||
'"dummy" will initialize the weights with random values, '
|
'"dummy" will initialize the weights with random values, '
|
||||||
'which is mainly for profiling.')
|
'which is mainly for profiling.')
|
||||||
# TODO(woosuk): Support FP32.
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dtype',
|
'--dtype',
|
||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.dtype,
|
default=EngineArgs.dtype,
|
||||||
choices=['auto', 'half', 'bfloat16', 'float'],
|
choices=[
|
||||||
|
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
|
||||||
|
],
|
||||||
help='data type for model weights and activations. '
|
help='data type for model weights and activations. '
|
||||||
'The "auto" option will use FP16 precision '
|
'The "auto" option will use FP16 precision '
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
'for FP32 and FP16 models, and BF16 precision '
|
||||||
'for BF16 models.')
|
'for BF16 models.')
|
||||||
|
parser.add_argument('--max-model-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='model context length. If unspecified, '
|
||||||
|
'will be automatically derived from the model.')
|
||||||
# Parallel arguments
|
# Parallel arguments
|
||||||
parser.add_argument('--worker-use-ray',
|
parser.add_argument('--worker-use-ray',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@@ -136,6 +151,13 @@ class EngineArgs:
|
|||||||
parser.add_argument('--disable-log-stats',
|
parser.add_argument('--disable-log-stats',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='disable logging statistics')
|
help='disable logging statistics')
|
||||||
|
# Quantization settings.
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
type=str,
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None,
|
||||||
|
help='Method used to quantize the weights')
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -149,20 +171,20 @@ class EngineArgs:
|
|||||||
def create_engine_configs(
|
def create_engine_configs(
|
||||||
self,
|
self,
|
||||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||||
# Initialize the configs.
|
|
||||||
model_config = ModelConfig(self.model, self.tokenizer,
|
model_config = ModelConfig(self.model, self.tokenizer,
|
||||||
self.tokenizer_mode, self.trust_remote_code,
|
self.tokenizer_mode, self.trust_remote_code,
|
||||||
self.download_dir, self.load_format,
|
self.download_dir, self.load_format,
|
||||||
self.dtype, self.seed)
|
self.dtype, self.seed, self.revision,
|
||||||
cache_config = CacheConfig(self.block_size,
|
self.max_model_len, self.quantization)
|
||||||
self.gpu_memory_utilization,
|
cache_config = CacheConfig(
|
||||||
self.swap_space)
|
self.block_size, self.gpu_memory_utilization, self.swap_space,
|
||||||
|
getattr(model_config.hf_config, 'sliding_window', None))
|
||||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||||
self.tensor_parallel_size,
|
self.tensor_parallel_size,
|
||||||
self.worker_use_ray)
|
self.worker_use_ray)
|
||||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||||
self.max_num_seqs,
|
self.max_num_seqs,
|
||||||
model_config.get_max_model_len())
|
model_config.max_model_len)
|
||||||
return model_config, cache_config, parallel_config, scheduler_config
|
return model_config, cache_config, parallel_config, scheduler_config
|
||||||
|
|
||||||
|
|
||||||
@@ -171,6 +193,7 @@ class AsyncEngineArgs(EngineArgs):
|
|||||||
"""Arguments for asynchronous vLLM engine."""
|
"""Arguments for asynchronous vLLM engine."""
|
||||||
engine_use_ray: bool = False
|
engine_use_ray: bool = False
|
||||||
disable_log_requests: bool = False
|
disable_log_requests: bool = False
|
||||||
|
max_log_len: Optional[int] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
@@ -183,4 +206,10 @@ class AsyncEngineArgs(EngineArgs):
|
|||||||
parser.add_argument('--disable-log-requests',
|
parser.add_argument('--disable-log-requests',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='disable logging requests')
|
help='disable logging requests')
|
||||||
|
parser.add_argument('--max-log-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='max number of prompt characters or prompt '
|
||||||
|
'ID numbers being printed in log. '
|
||||||
|
'Default: unlimited.')
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
|
||||||
|
Union)
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
@@ -78,14 +79,24 @@ class RequestTracker:
|
|||||||
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||||
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
||||||
dict]] = asyncio.Queue()
|
dict]] = asyncio.Queue()
|
||||||
|
self.new_requests_event = None
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item):
|
||||||
return item in self._request_streams
|
return item in self._request_streams
|
||||||
|
|
||||||
def propagate_exception(self, exc: Exception) -> None:
|
def init_event(self):
|
||||||
"""Propagate an exception to all request streams."""
|
self.new_requests_event = asyncio.Event()
|
||||||
for stream in self._request_streams.values():
|
|
||||||
stream.put(exc)
|
def propagate_exception(self,
|
||||||
|
exc: Exception,
|
||||||
|
request_id: Optional[str] = None) -> None:
|
||||||
|
"""Propagate an exception to request streams
|
||||||
|
(all if request_id is None)."""
|
||||||
|
if request_id is not None:
|
||||||
|
self._request_streams[request_id].put(exc)
|
||||||
|
else:
|
||||||
|
for stream in self._request_streams.values():
|
||||||
|
stream.put(exc)
|
||||||
|
|
||||||
def process_request_output(self,
|
def process_request_output(self,
|
||||||
request_output: RequestOutput,
|
request_output: RequestOutput,
|
||||||
@@ -112,6 +123,9 @@ class RequestTracker:
|
|||||||
"request_id": request_id,
|
"request_id": request_id,
|
||||||
**engine_add_request_kwargs
|
**engine_add_request_kwargs
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
self.new_requests_event.set()
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
||||||
@@ -148,8 +162,13 @@ class RequestTracker:
|
|||||||
self._request_streams[stream.request_id] = stream
|
self._request_streams[stream.request_id] = stream
|
||||||
new_requests.append(new_request)
|
new_requests.append(new_request)
|
||||||
|
|
||||||
|
self.new_requests_event.clear()
|
||||||
|
|
||||||
return new_requests, finished_requests
|
return new_requests, finished_requests
|
||||||
|
|
||||||
|
async def wait_for_new_requests(self):
|
||||||
|
await self.new_requests_event.wait()
|
||||||
|
|
||||||
|
|
||||||
class _AsyncLLMEngine(LLMEngine):
|
class _AsyncLLMEngine(LLMEngine):
|
||||||
"""Extension of LLMEngine to add async methods."""
|
"""Extension of LLMEngine to add async methods."""
|
||||||
@@ -164,10 +183,9 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
and updates the scheduler with the model outputs. Finally, it decodes
|
and updates the scheduler with the model outputs. Finally, it decodes
|
||||||
the sequences and returns the newly generated results.
|
the sequences and returns the newly generated results.
|
||||||
"""
|
"""
|
||||||
(seq_group_metadata_list, scheduler_outputs,
|
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||||
early_return) = self._schedule()
|
if scheduler_outputs.is_empty():
|
||||||
if early_return is not None:
|
return ignored
|
||||||
return early_return
|
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
output = await self._run_workers_async(
|
output = await self._run_workers_async(
|
||||||
@@ -178,7 +196,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._process_model_outputs(output, scheduler_outputs)
|
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||||
|
|
||||||
async def _run_workers_async(
|
async def _run_workers_async(
|
||||||
self,
|
self,
|
||||||
@@ -242,16 +260,22 @@ class AsyncLLMEngine:
|
|||||||
engine_use_ray: bool,
|
engine_use_ray: bool,
|
||||||
*args,
|
*args,
|
||||||
log_requests: bool = True,
|
log_requests: bool = True,
|
||||||
|
max_log_len: Optional[int] = None,
|
||||||
start_engine_loop: bool = True,
|
start_engine_loop: bool = True,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
self.worker_use_ray = worker_use_ray
|
self.worker_use_ray = worker_use_ray
|
||||||
self.engine_use_ray = engine_use_ray
|
self.engine_use_ray = engine_use_ray
|
||||||
self.log_requests = log_requests
|
self.log_requests = log_requests
|
||||||
|
self.max_log_len = max_log_len
|
||||||
self.engine = self._init_engine(*args, **kwargs)
|
self.engine = self._init_engine(*args, **kwargs)
|
||||||
|
|
||||||
self.request_tracker: RequestTracker = RequestTracker()
|
|
||||||
self.background_loop = None
|
self.background_loop = None
|
||||||
|
# We need to keep a reference to unshielded
|
||||||
|
# task as well to prevent it from being garbage
|
||||||
|
# collected
|
||||||
|
self._background_loop_unshielded = None
|
||||||
self.start_engine_loop = start_engine_loop
|
self.start_engine_loop = start_engine_loop
|
||||||
|
self._request_tracker = RequestTracker()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
@@ -262,11 +286,14 @@ class AsyncLLMEngine:
|
|||||||
"""Start the background loop."""
|
"""Start the background loop."""
|
||||||
if self.is_running:
|
if self.is_running:
|
||||||
raise RuntimeError("Background loop is already running.")
|
raise RuntimeError("Background loop is already running.")
|
||||||
self.background_loop = asyncio.get_event_loop().create_task(
|
self._request_tracker.init_event()
|
||||||
self.run_engine_loop())
|
|
||||||
self.background_loop.add_done_callback(
|
self._background_loop_unshielded = asyncio.get_event_loop(
|
||||||
|
).create_task(self.run_engine_loop())
|
||||||
|
self._background_loop_unshielded.add_done_callback(
|
||||||
partial(_raise_exception_on_finish,
|
partial(_raise_exception_on_finish,
|
||||||
request_tracker=self.request_tracker))
|
request_tracker=self._request_tracker))
|
||||||
|
self.background_loop = asyncio.shield(self._background_loop_unshielded)
|
||||||
|
|
||||||
def _init_engine(self, *args,
|
def _init_engine(self, *args,
|
||||||
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||||
@@ -278,11 +305,13 @@ class AsyncLLMEngine:
|
|||||||
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||||
return engine_class(*args, **kwargs)
|
return engine_class(*args, **kwargs)
|
||||||
|
|
||||||
async def engine_step(self):
|
async def engine_step(self) -> bool:
|
||||||
"""Kick the engine to process the waiting requests."""
|
"""Kick the engine to process the waiting requests.
|
||||||
|
|
||||||
|
Returns True if there are in-progress requests."""
|
||||||
|
|
||||||
new_requests, finished_requests = (
|
new_requests, finished_requests = (
|
||||||
self.request_tracker.get_new_and_finished_requests())
|
self._request_tracker.get_new_and_finished_requests())
|
||||||
|
|
||||||
for new_request in new_requests:
|
for new_request in new_requests:
|
||||||
# Add the request into the vLLM engine's waiting queue.
|
# Add the request into the vLLM engine's waiting queue.
|
||||||
@@ -302,9 +331,11 @@ class AsyncLLMEngine:
|
|||||||
|
|
||||||
# Put the outputs into the corresponding streams.
|
# Put the outputs into the corresponding streams.
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
self.request_tracker.process_request_output(
|
self._request_tracker.process_request_output(
|
||||||
request_output, verbose=self.log_requests)
|
request_output, verbose=self.log_requests)
|
||||||
|
|
||||||
|
return len(request_outputs) > 0
|
||||||
|
|
||||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||||
if self.engine_use_ray:
|
if self.engine_use_ray:
|
||||||
await self.engine.abort_request.remote(request_ids)
|
await self.engine.abort_request.remote(request_ids)
|
||||||
@@ -312,8 +343,12 @@ class AsyncLLMEngine:
|
|||||||
self.engine.abort_request(request_ids)
|
self.engine.abort_request(request_ids)
|
||||||
|
|
||||||
async def run_engine_loop(self):
|
async def run_engine_loop(self):
|
||||||
|
# Initialize the RequestTracker here so it uses the right event loop.
|
||||||
|
has_requests_in_progress = False
|
||||||
while True:
|
while True:
|
||||||
await self.engine_step()
|
if not has_requests_in_progress:
|
||||||
|
await self._request_tracker.wait_for_new_requests()
|
||||||
|
has_requests_in_progress = await self.engine_step()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
async def add_request(
|
async def add_request(
|
||||||
@@ -325,10 +360,18 @@ class AsyncLLMEngine:
|
|||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
) -> AsyncStream:
|
) -> AsyncStream:
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
|
shortened_prompt = prompt
|
||||||
|
shortened_token_ids = prompt_token_ids
|
||||||
|
if self.max_log_len is not None:
|
||||||
|
if shortened_prompt is not None:
|
||||||
|
shortened_prompt = shortened_prompt[:self.max_log_len]
|
||||||
|
if shortened_token_ids is not None:
|
||||||
|
shortened_token_ids = shortened_token_ids[:self.
|
||||||
|
max_log_len]
|
||||||
logger.info(f"Received request {request_id}: "
|
logger.info(f"Received request {request_id}: "
|
||||||
f"prompt: {prompt!r}, "
|
f"prompt: {shortened_prompt!r}, "
|
||||||
f"sampling params: {sampling_params}, "
|
f"sampling params: {sampling_params}, "
|
||||||
f"prompt token ids: {prompt_token_ids}.")
|
f"prompt token ids: {shortened_token_ids}.")
|
||||||
|
|
||||||
if not self.is_running:
|
if not self.is_running:
|
||||||
if self.start_engine_loop:
|
if self.start_engine_loop:
|
||||||
@@ -340,7 +383,7 @@ class AsyncLLMEngine:
|
|||||||
"error that caused the background loop to stop "
|
"error that caused the background loop to stop "
|
||||||
"(AsyncEngineDeadError).")
|
"(AsyncEngineDeadError).")
|
||||||
|
|
||||||
stream = self.request_tracker.add_request(
|
stream = self._request_tracker.add_request(
|
||||||
request_id,
|
request_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
@@ -385,8 +428,9 @@ class AsyncLLMEngine:
|
|||||||
|
|
||||||
async for request_output in stream:
|
async for request_output in stream:
|
||||||
yield request_output
|
yield request_output
|
||||||
except Exception as e:
|
except (Exception, asyncio.CancelledError) as e:
|
||||||
# If there is an exception, abort the request.
|
# If there is an exception or coroutine is cancelled, abort the
|
||||||
|
# request.
|
||||||
self._abort(request_id)
|
self._abort(request_id)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@@ -417,8 +461,8 @@ class AsyncLLMEngine:
|
|||||||
Args:
|
Args:
|
||||||
request_id: The unique id of the request.
|
request_id: The unique id of the request.
|
||||||
"""
|
"""
|
||||||
self.request_tracker.abort_request(request_id,
|
self._request_tracker.abort_request(request_id,
|
||||||
verbose=self.log_requests)
|
verbose=self.log_requests)
|
||||||
|
|
||||||
async def get_model_config(self) -> ModelConfig:
|
async def get_model_config(self) -> ModelConfig:
|
||||||
"""Get the model configuration of the vLLM engine."""
|
"""Get the model configuration of the vLLM engine."""
|
||||||
@@ -446,5 +490,6 @@ class AsyncLLMEngine:
|
|||||||
placement_group,
|
placement_group,
|
||||||
log_requests=not engine_args.disable_log_requests,
|
log_requests=not engine_args.disable_log_requests,
|
||||||
log_stats=not engine_args.disable_log_stats,
|
log_stats=not engine_args.disable_log_stats,
|
||||||
|
max_log_len=engine_args.max_log_len,
|
||||||
start_engine_loop=start_engine_loop)
|
start_engine_loop=start_engine_loop)
|
||||||
return engine
|
return engine
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ class LLMEngine:
|
|||||||
scheduler_config: The configuration related to the request scheduler.
|
scheduler_config: The configuration related to the request scheduler.
|
||||||
distributed_init_method: The initialization method for distributed
|
distributed_init_method: The initialization method for distributed
|
||||||
execution. See `torch.distributed.init_process_group` for details.
|
execution. See `torch.distributed.init_process_group` for details.
|
||||||
stage_devices: The list of devices for each stage. Each stage is a list
|
placement_group: Ray placement group for distributed execution.
|
||||||
of (rank, node_resource, device) tuples.
|
Required for distributed execution.
|
||||||
log_stats: Whether to log statistics.
|
log_stats: Whether to log statistics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -74,16 +74,21 @@ class LLMEngine:
|
|||||||
f"model={model_config.model!r}, "
|
f"model={model_config.model!r}, "
|
||||||
f"tokenizer={model_config.tokenizer!r}, "
|
f"tokenizer={model_config.tokenizer!r}, "
|
||||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||||
|
f"revision={model_config.revision}, "
|
||||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||||
f"dtype={model_config.dtype}, "
|
f"dtype={model_config.dtype}, "
|
||||||
|
f"max_seq_len={model_config.max_model_len}, "
|
||||||
f"download_dir={model_config.download_dir!r}, "
|
f"download_dir={model_config.download_dir!r}, "
|
||||||
f"load_format={model_config.load_format}, "
|
f"load_format={model_config.load_format}, "
|
||||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||||
|
f"quantization={model_config.quantization}, "
|
||||||
f"seed={model_config.seed})")
|
f"seed={model_config.seed})")
|
||||||
# TODO(woosuk): Print more configs in debug mode.
|
# TODO(woosuk): Print more configs in debug mode.
|
||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
|
assert self.cache_config.sliding_window == getattr(
|
||||||
|
self.model_config.hf_config, "sliding_window", None)
|
||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.log_stats = log_stats
|
self.log_stats = log_stats
|
||||||
@@ -92,7 +97,8 @@ class LLMEngine:
|
|||||||
self.tokenizer = get_tokenizer(
|
self.tokenizer = get_tokenizer(
|
||||||
model_config.tokenizer,
|
model_config.tokenizer,
|
||||||
tokenizer_mode=model_config.tokenizer_mode,
|
tokenizer_mode=model_config.tokenizer_mode,
|
||||||
trust_remote_code=model_config.trust_remote_code)
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
|
revision=model_config.revision)
|
||||||
self.seq_counter = Counter()
|
self.seq_counter = Counter()
|
||||||
|
|
||||||
# Create the parallel GPU workers.
|
# Create the parallel GPU workers.
|
||||||
@@ -153,7 +159,7 @@ class LLMEngine:
|
|||||||
placement_group=placement_group,
|
placement_group=placement_group,
|
||||||
placement_group_capture_child_tasks=True),
|
placement_group_capture_child_tasks=True),
|
||||||
**ray_remote_kwargs,
|
**ray_remote_kwargs,
|
||||||
)(RayWorker).remote()
|
)(RayWorker).remote(self.model_config.trust_remote_code)
|
||||||
self.workers.append(worker)
|
self.workers.append(worker)
|
||||||
|
|
||||||
# Initialize torch distributed process group for the workers.
|
# Initialize torch distributed process group for the workers.
|
||||||
@@ -291,14 +297,12 @@ class LLMEngine:
|
|||||||
def _schedule(
|
def _schedule(
|
||||||
self
|
self
|
||||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||||
Optional[List[RequestOutput]]]:
|
List[RequestOutput]]:
|
||||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||||
if scheduler_outputs.is_empty():
|
return seq_group_metadata_list, scheduler_outputs, [
|
||||||
return seq_group_metadata_list, scheduler_outputs, [
|
RequestOutput.from_seq_group(seq_group)
|
||||||
RequestOutput.from_seq_group(seq_group)
|
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
]
|
||||||
]
|
|
||||||
return seq_group_metadata_list, scheduler_outputs, None
|
|
||||||
|
|
||||||
def _check_beam_search_early_stopping(
|
def _check_beam_search_early_stopping(
|
||||||
self,
|
self,
|
||||||
@@ -386,7 +390,7 @@ class LLMEngine:
|
|||||||
child_seqs.append((parent, parent))
|
child_seqs.append((parent, parent))
|
||||||
|
|
||||||
for seq, _ in child_seqs:
|
for seq, _ in child_seqs:
|
||||||
self._decode_sequence(seq)
|
self._decode_sequence(seq, seq_group.sampling_params)
|
||||||
self._check_stop(seq, seq_group.sampling_params)
|
self._check_stop(seq, seq_group.sampling_params)
|
||||||
|
|
||||||
# Non-beam search case
|
# Non-beam search case
|
||||||
@@ -542,10 +546,9 @@ class LLMEngine:
|
|||||||
and updates the scheduler with the model outputs. Finally, it decodes
|
and updates the scheduler with the model outputs. Finally, it decodes
|
||||||
the sequences and returns the newly generated results.
|
the sequences and returns the newly generated results.
|
||||||
"""
|
"""
|
||||||
(seq_group_metadata_list, scheduler_outputs,
|
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||||
early_return) = self._schedule()
|
if scheduler_outputs.is_empty():
|
||||||
if early_return is not None:
|
return ignored
|
||||||
return early_return
|
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
output = self._run_workers(
|
output = self._run_workers(
|
||||||
@@ -556,7 +559,7 @@ class LLMEngine:
|
|||||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._process_model_outputs(output, scheduler_outputs)
|
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||||
|
|
||||||
def _log_system_stats(
|
def _log_system_stats(
|
||||||
self,
|
self,
|
||||||
@@ -621,17 +624,25 @@ class LLMEngine:
|
|||||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||||
self.last_logging_time = now
|
self.last_logging_time = now
|
||||||
|
|
||||||
def _decode_sequence(self, seq: Sequence) -> None:
|
def _decode_sequence(self, seq: Sequence,
|
||||||
|
sampling_params: SamplingParams) -> None:
|
||||||
"""Decodes the new token for a sequence."""
|
"""Decodes the new token for a sequence."""
|
||||||
new_token, new_output_text = detokenize_incrementally(
|
(new_tokens, new_output_text, prefix_offset,
|
||||||
self.tokenizer,
|
read_offset) = detokenize_incrementally(
|
||||||
seq.output_tokens,
|
self.tokenizer,
|
||||||
seq.get_last_token_id(),
|
all_input_ids=seq.get_token_ids(),
|
||||||
skip_special_tokens=True,
|
prev_tokens=seq.tokens,
|
||||||
)
|
prefix_offset=seq.prefix_offset,
|
||||||
if new_token is not None:
|
read_offset=seq.read_offset,
|
||||||
seq.output_tokens.append(new_token)
|
skip_special_tokens=sampling_params.skip_special_tokens,
|
||||||
seq.output_text = new_output_text
|
)
|
||||||
|
if seq.tokens is None:
|
||||||
|
seq.tokens = new_tokens
|
||||||
|
else:
|
||||||
|
seq.tokens.extend(new_tokens)
|
||||||
|
seq.prefix_offset = prefix_offset
|
||||||
|
seq.read_offset = read_offset
|
||||||
|
seq.output_text += new_output_text
|
||||||
|
|
||||||
def _check_stop(self, seq: Sequence,
|
def _check_stop(self, seq: Sequence,
|
||||||
sampling_params: SamplingParams) -> None:
|
sampling_params: SamplingParams) -> None:
|
||||||
@@ -643,6 +654,9 @@ class LLMEngine:
|
|||||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
return
|
return
|
||||||
|
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||||
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
|
return
|
||||||
|
|
||||||
# Check if the sequence has reached max_model_len.
|
# Check if the sequence has reached max_model_len.
|
||||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ import socket
|
|||||||
from typing import Optional, Tuple, TYPE_CHECKING
|
from typing import Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
@@ -11,7 +14,11 @@ try:
|
|||||||
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
|
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
|
||||||
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, init_cached_hf_modules=False) -> None:
|
||||||
|
if init_cached_hf_modules:
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from transformers.dynamic_module_utils import init_hf_modules
|
||||||
|
init_hf_modules()
|
||||||
self.worker = None
|
self.worker = None
|
||||||
|
|
||||||
def init_worker(self, worker_init_fn):
|
def init_worker(self, worker_init_fn):
|
||||||
@@ -24,7 +31,10 @@ try:
|
|||||||
executor = getattr(self, method)
|
executor = getattr(self, method)
|
||||||
return executor(*args, **kwargs)
|
return executor(*args, **kwargs)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
|
logger.warning(f"Failed to import Ray with {e!r}. "
|
||||||
|
"For distributed inference, please install Ray with "
|
||||||
|
"`pip install ray pandas pyarrow`.")
|
||||||
ray = None
|
ray = None
|
||||||
TorchDistributedWorker = None
|
TorchDistributedWorker = None
|
||||||
RayWorker = None # pylint: disable=invalid-name
|
RayWorker = None # pylint: disable=invalid-name
|
||||||
@@ -53,11 +63,10 @@ def initialize_cluster(
|
|||||||
the default Ray cluster address.
|
the default Ray cluster address.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (`distributed_init_method`, `all_stage_devices`). The
|
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||||
`distributed_init_method` is the address for initializing the
|
`distributed_init_method` is the address for initializing the
|
||||||
distributed backend. `all_stage_devices` includes device IDs for
|
distributed backend. `placement_group` includes the specification
|
||||||
each worker in each pipeline stage. Each device ID is a tuple of
|
of the resources for each distributed worker.
|
||||||
(rank, node resource, device id).
|
|
||||||
"""
|
"""
|
||||||
if parallel_config.worker_use_ray or engine_use_ray:
|
if parallel_config.worker_use_ray or engine_use_ray:
|
||||||
if ray is None:
|
if ray is None:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from fastapi import BackgroundTasks, FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
@@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
|
|||||||
ret = {"text": text_outputs}
|
ret = {"text": text_outputs}
|
||||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
background_tasks = BackgroundTasks()
|
return StreamingResponse(stream_results())
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(stream_results(), background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming case
|
# Non-streaming case
|
||||||
final_output = None
|
final_output = None
|
||||||
|
|||||||
@@ -37,7 +37,22 @@ class LLM:
|
|||||||
the `torch_dtype` attribute specified in the model config file.
|
the `torch_dtype` attribute specified in the model config file.
|
||||||
However, if the `torch_dtype` in the config is `float32`, we will
|
However, if the `torch_dtype` in the config is `float32`, we will
|
||||||
use `float16` instead.
|
use `float16` instead.
|
||||||
|
quantization: The method used to quantize the model weights. Currently,
|
||||||
|
we support "awq". If None, we assume the model weights are not
|
||||||
|
quantized and use `dtype` to determine the data type of the weights.
|
||||||
|
revision: The specific model version to use. It can be a branch name,
|
||||||
|
a tag name, or a commit id.
|
||||||
seed: The seed to initialize the random number generator for sampling.
|
seed: The seed to initialize the random number generator for sampling.
|
||||||
|
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||||
|
reserve for the model weights, activations, and KV cache. Higher
|
||||||
|
values will increase the KV cache size and thus improve the model's
|
||||||
|
throughput. However, if the value is too high, it may cause out-of-
|
||||||
|
memory (OOM) errors.
|
||||||
|
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||||
|
This can be used for temporarily storing the states of the requests
|
||||||
|
when their `best_of` sampling parameters are larger than 1. If all
|
||||||
|
requests will have `best_of=1`, you can safely set this to 0.
|
||||||
|
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -48,7 +63,11 @@ class LLM:
|
|||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
tensor_parallel_size: int = 1,
|
tensor_parallel_size: int = 1,
|
||||||
dtype: str = "auto",
|
dtype: str = "auto",
|
||||||
|
quantization: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
|
gpu_memory_utilization: float = 0.9,
|
||||||
|
swap_space: int = 4,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
if "disable_log_stats" not in kwargs:
|
if "disable_log_stats" not in kwargs:
|
||||||
@@ -60,7 +79,11 @@ class LLM:
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
quantization=quantization,
|
||||||
|
revision=revision,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
swap_space=swap_space,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import BackgroundTasks, Request
|
from fastapi import Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -130,6 +130,8 @@ async def check_length(
|
|||||||
input_ids = tokenizer(prompt).input_ids
|
input_ids = tokenizer(prompt).input_ids
|
||||||
token_num = len(input_ids)
|
token_num = len(input_ids)
|
||||||
|
|
||||||
|
if request.max_tokens is None:
|
||||||
|
request.max_tokens = max_model_len - token_num
|
||||||
if token_num + request.max_tokens > max_model_len:
|
if token_num + request.max_tokens > max_model_len:
|
||||||
return input_ids, create_error_response(
|
return input_ids, create_error_response(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
@@ -196,7 +198,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
|
|
||||||
if request.logit_bias is not None:
|
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||||
# TODO: support logit_bias in vLLM engine.
|
# TODO: support logit_bias in vLLM engine.
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"logit_bias is not currently supported")
|
"logit_bias is not currently supported")
|
||||||
@@ -217,11 +219,13 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
stop=request.stop,
|
stop=request.stop,
|
||||||
|
stop_token_ids=request.stop_token_ids,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
best_of=request.best_of,
|
best_of=request.best_of,
|
||||||
top_k=request.top_k,
|
top_k=request.top_k,
|
||||||
ignore_eos=request.ignore_eos,
|
ignore_eos=request.ignore_eos,
|
||||||
use_beam_search=request.use_beam_search,
|
use_beam_search=request.use_beam_search,
|
||||||
|
skip_special_tokens=request.skip_special_tokens,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||||
@@ -229,9 +233,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||||
token_ids)
|
token_ids)
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
def create_stream_response_json(
|
def create_stream_response_json(
|
||||||
index: int,
|
index: int,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -291,19 +292,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
if request.stream:
|
if request.stream:
|
||||||
background_tasks = BackgroundTasks()
|
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(completion_stream_generator(),
|
return StreamingResponse(completion_stream_generator(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream")
|
||||||
background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming response
|
# Non-streaming response
|
||||||
final_res: RequestOutput = None
|
final_res: RequestOutput = None
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
if await raw_request.is_disconnected():
|
if await raw_request.is_disconnected():
|
||||||
# Abort the request if the client disconnects.
|
# Abort the request if the client disconnects.
|
||||||
await abort_request()
|
await engine.abort(request_id)
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"Client disconnected")
|
"Client disconnected")
|
||||||
final_res = res
|
final_res = res
|
||||||
@@ -379,7 +376,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"suffix is not currently supported")
|
"suffix is not currently supported")
|
||||||
|
|
||||||
if request.logit_bias is not None:
|
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||||
# TODO: support logit_bias in vLLM engine.
|
# TODO: support logit_bias in vLLM engine.
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"logit_bias is not currently supported")
|
"logit_bias is not currently supported")
|
||||||
@@ -425,10 +422,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
top_k=request.top_k,
|
top_k=request.top_k,
|
||||||
stop=request.stop,
|
stop=request.stop,
|
||||||
|
stop_token_ids=request.stop_token_ids,
|
||||||
ignore_eos=request.ignore_eos,
|
ignore_eos=request.ignore_eos,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
logprobs=request.logprobs,
|
logprobs=request.logprobs,
|
||||||
use_beam_search=request.use_beam_search,
|
use_beam_search=request.use_beam_search,
|
||||||
|
skip_special_tokens=request.skip_special_tokens,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||||
@@ -448,9 +447,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
and (request.best_of is None or request.n == request.best_of)
|
and (request.best_of is None or request.n == request.best_of)
|
||||||
and not request.use_beam_search)
|
and not request.use_beam_search)
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
def create_stream_response_json(
|
def create_stream_response_json(
|
||||||
index: int,
|
index: int,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -510,19 +506,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
if stream:
|
if stream:
|
||||||
background_tasks = BackgroundTasks()
|
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(completion_stream_generator(),
|
return StreamingResponse(completion_stream_generator(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream")
|
||||||
background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming response
|
# Non-streaming response
|
||||||
final_res: RequestOutput = None
|
final_res: RequestOutput = None
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
if await raw_request.is_disconnected():
|
if await raw_request.is_disconnected():
|
||||||
# Abort the request if the client disconnects.
|
# Abort the request if the client disconnects.
|
||||||
await abort_request()
|
await engine.abort(request_id)
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"Client disconnected")
|
"Client disconnected")
|
||||||
final_res = res
|
final_res = res
|
||||||
@@ -623,7 +615,7 @@ if __name__ == "__main__":
|
|||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
engine_model_config = asyncio.run(engine.get_model_config())
|
engine_model_config = asyncio.run(engine.get_model_config())
|
||||||
max_model_len = engine_model_config.get_max_model_len()
|
max_model_len = engine_model_config.max_model_len
|
||||||
|
|
||||||
# A separate tokenizer to map token IDs to strings.
|
# A separate tokenizer to map token IDs to strings.
|
||||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
temperature: Optional[float] = 0.7
|
temperature: Optional[float] = 0.7
|
||||||
top_p: Optional[float] = 1.0
|
top_p: Optional[float] = 1.0
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
max_tokens: Optional[int] = 16
|
max_tokens: Optional[int] = None
|
||||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
presence_penalty: Optional[float] = 0.0
|
presence_penalty: Optional[float] = 0.0
|
||||||
@@ -70,6 +70,8 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
top_k: Optional[int] = -1
|
top_k: Optional[int] = -1
|
||||||
ignore_eos: Optional[bool] = False
|
ignore_eos: Optional[bool] = False
|
||||||
use_beam_search: Optional[bool] = False
|
use_beam_search: Optional[bool] = False
|
||||||
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
skip_special_tokens: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
@@ -94,6 +96,8 @@ class CompletionRequest(BaseModel):
|
|||||||
top_k: Optional[int] = -1
|
top_k: Optional[int] = -1
|
||||||
ignore_eos: Optional[bool] = False
|
ignore_eos: Optional[bool] = False
|
||||||
use_beam_search: Optional[bool] = False
|
use_beam_search: Optional[bool] = False
|
||||||
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
skip_special_tokens: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
class LogProbs(BaseModel):
|
class LogProbs(BaseModel):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from xformers.ops import AttentionBias
|
from xformers.ops import AttentionBias
|
||||||
@@ -29,6 +29,7 @@ class InputMetadata:
|
|||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
max_context_len: int,
|
max_context_len: int,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.seq_groups = seq_groups
|
self.seq_groups = seq_groups
|
||||||
self.seq_data = seq_data
|
self.seq_data = seq_data
|
||||||
@@ -38,6 +39,24 @@ class InputMetadata:
|
|||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
self.block_tables = block_tables
|
self.block_tables = block_tables
|
||||||
|
|
||||||
|
self.to_cache = None
|
||||||
|
if sliding_window is not None:
|
||||||
|
# We need to keep the positions of sliding windows within
|
||||||
|
# the key / value tables, this is helpful to know which
|
||||||
|
# elements we need to cache and where
|
||||||
|
to_cache, start_idx = [], 0
|
||||||
|
for prompt_len in self.prompt_lens:
|
||||||
|
to_cache.extend(
|
||||||
|
range(
|
||||||
|
start_idx + max(0, prompt_len - sliding_window),
|
||||||
|
start_idx + prompt_len,
|
||||||
|
))
|
||||||
|
start_idx += prompt_len
|
||||||
|
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
|
||||||
|
self.to_cache = torch.tensor(to_cache,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.slot_mapping.device)
|
||||||
|
|
||||||
self.num_prompts = len(prompt_lens)
|
self.num_prompts = len(prompt_lens)
|
||||||
self.num_prompt_tokens = sum(prompt_lens)
|
self.num_prompt_tokens = sum(prompt_lens)
|
||||||
self.num_generation_tokens = context_lens.shape[0]
|
self.num_generation_tokens = context_lens.shape[0]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Multi-head attention."""
|
"""Multi-head attention."""
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -9,8 +9,10 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
|||||||
|
|
||||||
from vllm import attention_ops
|
from vllm import attention_ops
|
||||||
from vllm import cache_ops
|
from vllm import cache_ops
|
||||||
from vllm import pos_encoding_ops
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import (
|
||||||
|
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
|
||||||
|
RotaryEmbedding)
|
||||||
|
|
||||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
|
||||||
@@ -56,12 +58,14 @@ class PagedAttention(nn.Module):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
num_kv_heads: Optional[int] = None) -> None:
|
num_kv_heads: Optional[int] = None,
|
||||||
|
sliding_window: Optional[int] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
@@ -73,12 +77,19 @@ class PagedAttention(nn.Module):
|
|||||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||||
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
||||||
|
|
||||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
def set_attn_bias(
|
||||||
|
self,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> None:
|
||||||
|
del dtype # Unused.
|
||||||
if input_metadata.attn_bias:
|
if input_metadata.attn_bias:
|
||||||
# Already set by a previous layer.
|
# Already set by a previous layer.
|
||||||
return
|
return
|
||||||
prompt_lens = input_metadata.prompt_lens
|
prompt_lens = input_metadata.prompt_lens
|
||||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
attn_bias = attn_bias.make_local_attention(self.sliding_window)
|
||||||
input_metadata.attn_bias.append(attn_bias)
|
input_metadata.attn_bias.append(attn_bias)
|
||||||
|
|
||||||
def multi_query_kv_attention(
|
def multi_query_kv_attention(
|
||||||
@@ -196,7 +207,7 @@ class PagedAttention(nn.Module):
|
|||||||
if num_prompt_tokens > 0:
|
if num_prompt_tokens > 0:
|
||||||
# Prompt run.
|
# Prompt run.
|
||||||
assert input_metadata.num_generation_tokens == 0
|
assert input_metadata.num_generation_tokens == 0
|
||||||
self.set_attn_bias(input_metadata)
|
self.set_attn_bias(input_metadata, dtype=query.dtype)
|
||||||
self.multi_query_kv_attention(
|
self.multi_query_kv_attention(
|
||||||
output[:num_prompt_tokens],
|
output[:num_prompt_tokens],
|
||||||
query[:num_prompt_tokens],
|
query[:num_prompt_tokens],
|
||||||
@@ -216,12 +227,20 @@ class PagedAttention(nn.Module):
|
|||||||
if (num_valid_tokens > 0 and key_cache is not None
|
if (num_valid_tokens > 0 and key_cache is not None
|
||||||
and value_cache is not None):
|
and value_cache is not None):
|
||||||
# The stride is 3 because the key and value are sliced from qkv.
|
# The stride is 3 because the key and value are sliced from qkv.
|
||||||
|
key_to_cache = key[:num_valid_tokens]
|
||||||
|
value_to_cache = value[:num_valid_tokens]
|
||||||
|
slot_mapping = input_metadata.slot_mapping
|
||||||
|
if input_metadata.to_cache is not None:
|
||||||
|
key_to_cache = key_to_cache[input_metadata.to_cache]
|
||||||
|
value_to_cache = value_to_cache[input_metadata.to_cache]
|
||||||
|
slot_mapping = slot_mapping[input_metadata.to_cache]
|
||||||
|
|
||||||
cache_ops.reshape_and_cache(
|
cache_ops.reshape_and_cache(
|
||||||
key[:num_valid_tokens],
|
key_to_cache,
|
||||||
value[:num_valid_tokens],
|
value_to_cache,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
input_metadata.slot_mapping,
|
slot_mapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
if input_metadata.num_generation_tokens > 0:
|
if input_metadata.num_generation_tokens > 0:
|
||||||
@@ -242,7 +261,7 @@ class PagedAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PagedAttentionWithRoPE(PagedAttention):
|
class PagedAttentionWithRoPE(PagedAttention):
|
||||||
"""PagedAttention with rotary embedding."""
|
"""PagedAttention with rotary positional embedding."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -254,26 +273,31 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
is_neox_style: bool = True,
|
is_neox_style: bool = True,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
super().__init__(num_heads,
|
||||||
self.is_neox_style = is_neox_style
|
head_size,
|
||||||
|
scale,
|
||||||
# Create the cos and sin cache.
|
num_kv_heads,
|
||||||
inv_freq = 1.0 / (base**(
|
sliding_window=sliding_window)
|
||||||
torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
|
if rope_scaling is None:
|
||||||
t = torch.arange(max_position, device="cuda").float()
|
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
max_position, base,
|
||||||
cos = freqs.cos()
|
is_neox_style)
|
||||||
sin = freqs.sin()
|
else:
|
||||||
cache = torch.cat((cos, sin), dim=-1)
|
scaling_type = rope_scaling["type"]
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
# FIXME(woosuk): This assumes that we configure the default dtype when
|
if scaling_type == "linear":
|
||||||
# initializing the model.
|
self.rotary_emb = LinearScalingRotaryEmbedding(
|
||||||
# TODO(woosuk): Make it more robust.
|
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
torch_dtype = torch.get_default_dtype()
|
scaling_factor)
|
||||||
cache = cache.to(torch_dtype)
|
elif scaling_type == "dynamic":
|
||||||
# Embedding size: [max_position, rotary_dim]
|
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
|
scaling_factor)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -290,7 +314,7 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
positions: shape = [num_tokens]
|
positions: shape = [num_tokens]
|
||||||
query: shape = [num_tokens, num_heads * head_size]
|
query: shape = [num_tokens, num_heads * head_size]
|
||||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||||
@@ -306,14 +330,7 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
|
|
||||||
# Apply rotary embedding to the query and key before passing them
|
# Apply rotary embedding to the query and key before passing them
|
||||||
# to the attention op.
|
# to the attention op.
|
||||||
pos_encoding_ops.rotary_embedding(
|
query, key = self.rotary_emb(positions, query, key)
|
||||||
positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
self.head_size,
|
|
||||||
self.cos_sin_cache,
|
|
||||||
self.is_neox_style,
|
|
||||||
)
|
|
||||||
return super().forward(
|
return super().forward(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -340,13 +357,14 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
slopes = torch.tensor(slopes, dtype=torch.float32)
|
slopes = torch.tensor(slopes, dtype=torch.float32)
|
||||||
self.register_buffer("alibi_slopes", slopes, persistent=False)
|
self.register_buffer("alibi_slopes", slopes, persistent=False)
|
||||||
|
|
||||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
def set_attn_bias(self, input_metadata: InputMetadata,
|
||||||
|
dtype: torch.dtype) -> None:
|
||||||
if input_metadata.attn_bias:
|
if input_metadata.attn_bias:
|
||||||
# Already set by a previous layer.
|
# Already set by a previous layer.
|
||||||
return
|
return
|
||||||
# Generates ALiBi mask for each prompt.
|
# Generates ALiBi mask for each prompt.
|
||||||
for prompt_len in input_metadata.prompt_lens:
|
for prompt_len in input_metadata.prompt_lens:
|
||||||
bias = torch.arange(prompt_len)
|
bias = torch.arange(prompt_len, dtype=dtype)
|
||||||
# Note(zhuohan): HF uses
|
# Note(zhuohan): HF uses
|
||||||
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
||||||
# here. We find that both biases give the same results, but
|
# here. We find that both biases give the same results, but
|
||||||
@@ -364,6 +382,7 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
prompt_len,
|
prompt_len,
|
||||||
padded_len,
|
padded_len,
|
||||||
device=self.alibi_slopes.device,
|
device=self.alibi_slopes.device,
|
||||||
|
dtype=dtype,
|
||||||
)[:, :, :, :prompt_len].copy_(bias)
|
)[:, :, :, :prompt_len].copy_(bias)
|
||||||
bias.mul_(self.alibi_slopes[:, None, None])
|
bias.mul_(self.alibi_slopes[:, None, None])
|
||||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||||
|
|||||||
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from vllm.model_executor.layers.quantized_linear.awq import (
|
||||||
|
AWQColumnParallelLinear, AWQRowParallelLinear)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
ColumnParallelLinear, RowParallelLinear)
|
||||||
|
|
||||||
|
_QUANTIZED_LINEAR_REGISTRY = {
|
||||||
|
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLinear:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
|
||||||
|
quant_config = kwargs.get("quant_config", None)
|
||||||
|
if quant_config is None:
|
||||||
|
return ColumnParallelLinear(*args, **kwargs)
|
||||||
|
|
||||||
|
name = quant_config.get_name()
|
||||||
|
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||||
|
raise ValueError(f"No quantized linear is found for {name}")
|
||||||
|
|
||||||
|
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
|
||||||
|
return quant_linear_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def row(cls, *args, **kwargs) -> RowParallelLinear:
|
||||||
|
quant_config = kwargs.get("quant_config", None)
|
||||||
|
if quant_config is None:
|
||||||
|
return RowParallelLinear(*args, **kwargs)
|
||||||
|
|
||||||
|
name = quant_config.get_name()
|
||||||
|
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||||
|
raise ValueError(f"No quantized linear is found for {name}")
|
||||||
|
|
||||||
|
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
|
||||||
|
return quant_linear_cls(*args, **kwargs)
|
||||||
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from vllm import quantization_ops
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
|
||||||
|
ColumnParallelLinear, RowParallelLinear)
|
||||||
|
|
||||||
|
|
||||||
|
class AWQColumnParallelLinear(ColumnParallelLinear):
|
||||||
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
assert self.input_size % self.quant_config.weight_bits == 0
|
||||||
|
assert (self.output_size_per_partition %
|
||||||
|
self.quant_config.pack_factor == 0)
|
||||||
|
self.qweight = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size,
|
||||||
|
self.output_size_per_partition //
|
||||||
|
self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.qzeros = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size // self.quant_config.group_size,
|
||||||
|
self.output_size_per_partition //
|
||||||
|
self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.scales = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size // self.quant_config.group_size,
|
||||||
|
self.output_size_per_partition,
|
||||||
|
device="cuda",
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
pack_factor = self.quant_config.pack_factor
|
||||||
|
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||||
|
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||||
|
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||||
|
self.qzeros, pack_factor)
|
||||||
|
if bias is not None:
|
||||||
|
out = out + bias
|
||||||
|
return out.reshape(out_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class AWQRowParallelLinear(RowParallelLinear):
|
||||||
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
assert (self.input_size_per_partition %
|
||||||
|
self.quant_config.weight_bits == 0)
|
||||||
|
assert self.output_size % self.quant_config.pack_factor == 0
|
||||||
|
self.qweight = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size_per_partition,
|
||||||
|
self.output_size // self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.qzeros = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size_per_partition // self.quant_config.group_size,
|
||||||
|
self.output_size // self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.scales = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size_per_partition // self.quant_config.group_size,
|
||||||
|
self.output_size,
|
||||||
|
device="cuda",
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
pack_factor = self.quant_config.pack_factor
|
||||||
|
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||||
|
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||||
|
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||||
|
self.qzeros, pack_factor)
|
||||||
|
return out.reshape(out_shape)
|
||||||
169
vllm/model_executor/layers/rotary_embedding.py
Normal file
169
vllm/model_executor/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Rotary Positional Embeddings."""
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm import pos_encoding_ops
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
"""Original rotary positional embedding."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.head_size = head_size
|
||||||
|
self.rotary_dim = rotary_dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
self.is_neox_style = is_neox_style
|
||||||
|
|
||||||
|
cache = self._compute_cos_sin_cache()
|
||||||
|
cache = cache.to(torch.get_default_dtype())
|
||||||
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
|
"""Compute the inverse frequency."""
|
||||||
|
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
|
||||||
|
# However, we use `torch.arange(..., dtype=torch.float)` instead to
|
||||||
|
# avoid numerical issues with large base values (e.g., 10000000).
|
||||||
|
# This may cause a slight numerical difference between the HF
|
||||||
|
# implementation and ours.
|
||||||
|
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||||
|
# use CPU to compute the cache and then move it to GPU. However, we
|
||||||
|
# create the cache on GPU for faster initialization. This may cause
|
||||||
|
# a slight numerical difference between the HF implementation and ours.
|
||||||
|
inv_freq = 1.0 / (base**(torch.arange(
|
||||||
|
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
|
||||||
|
self.rotary_dim))
|
||||||
|
return inv_freq
|
||||||
|
|
||||||
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
|
"""Compute the cos and sin cache."""
|
||||||
|
inv_freq = self._compute_inv_freq(self.base)
|
||||||
|
t = torch.arange(self.max_position_embeddings,
|
||||||
|
dtype=torch.float,
|
||||||
|
device="cuda")
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
return cache
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# pos_encoding_ops.rotary_embedding() is an in-place operation that
|
||||||
|
# updates the query and key tensors.
|
||||||
|
pos_encoding_ops.rotary_embedding(positions, query, key,
|
||||||
|
self.head_size, self.cos_sin_cache,
|
||||||
|
self.is_neox_style)
|
||||||
|
return query, key
|
||||||
|
|
||||||
|
|
||||||
|
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||||
|
"""RotaryEmbedding extended with linear scaling.
|
||||||
|
|
||||||
|
Credits to the Reddit user /u/kaiokendev
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
scaling_factor: float,
|
||||||
|
) -> None:
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||||
|
is_neox_style)
|
||||||
|
|
||||||
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
|
inv_freq = self._compute_inv_freq(self.base)
|
||||||
|
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||||
|
# maximum length before applying the rope scaling.
|
||||||
|
# Thus, the maximum length after applying the rope scaling is
|
||||||
|
# self.max_position_embeddings * self.scaling_factor.
|
||||||
|
max_len = self.max_position_embeddings * self.scaling_factor
|
||||||
|
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||||
|
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
||||||
|
|
||||||
|
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
scaling_factor: float,
|
||||||
|
) -> None:
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||||
|
is_neox_style)
|
||||||
|
|
||||||
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
|
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||||
|
# maximum length before applying the rope scaling.
|
||||||
|
# Thus, the maximum length after applying the rope scaling is
|
||||||
|
# self.max_position_embeddings * self.scaling_factor.
|
||||||
|
max_len = self.max_position_embeddings * self.scaling_factor
|
||||||
|
base = self.base * (
|
||||||
|
(self.scaling_factor * max_len / self.max_position_embeddings) -
|
||||||
|
(self.scaling_factor - 1))**(self.rotary_dim /
|
||||||
|
(self.rotary_dim - 2))
|
||||||
|
inv_freq = self._compute_inv_freq(base)
|
||||||
|
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
return cache
|
||||||
@@ -1,15 +1,14 @@
|
|||||||
"""A layer that samples the next tokens from the model's outputs."""
|
"""A layer that samples the next tokens from the model's outputs."""
|
||||||
from typing import Dict, List, Tuple, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
gather_from_tensor_model_parallel_region)
|
gather_from_tensor_model_parallel_region)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import SamplerOutput, SequenceOutputs
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
@@ -44,12 +43,8 @@ class Sampler(nn.Module):
|
|||||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
||||||
|
|
||||||
# Get the logits for the next tokens.
|
# Get the logits for the next tokens.
|
||||||
logits = torch.matmul(hidden_states, embedding.t())
|
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
||||||
if embedding_bias is not None:
|
self.vocab_size)
|
||||||
logits += embedding_bias
|
|
||||||
logits = gather_from_tensor_model_parallel_region(logits)
|
|
||||||
# Remove paddings in vocab (if any).
|
|
||||||
logits = logits[:, :self.vocab_size]
|
|
||||||
|
|
||||||
# Apply presence and frequency penalties.
|
# Apply presence and frequency penalties.
|
||||||
output_tokens = _get_output_tokens(input_metadata)
|
output_tokens = _get_output_tokens(input_metadata)
|
||||||
@@ -59,7 +54,7 @@ class Sampler(nn.Module):
|
|||||||
assert len(presence_penalties) == logits.shape[0]
|
assert len(presence_penalties) == logits.shape[0]
|
||||||
assert len(frequency_penalties) == logits.shape[0]
|
assert len(frequency_penalties) == logits.shape[0]
|
||||||
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
||||||
frequency_penalties, self.vocab_size)
|
frequency_penalties)
|
||||||
|
|
||||||
# Apply temperature scaling.
|
# Apply temperature scaling.
|
||||||
temperatures = _get_temperatures(input_metadata)
|
temperatures = _get_temperatures(input_metadata)
|
||||||
@@ -82,26 +77,55 @@ class Sampler(nn.Module):
|
|||||||
# We use float32 for probabilities and log probabilities.
|
# We use float32 for probabilities and log probabilities.
|
||||||
# Compute the probabilities.
|
# Compute the probabilities.
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||||
# Compute the log probabilities (before applying top-p and top-k).
|
# Compute the log probabilities.
|
||||||
logprobs = torch.log(probs)
|
# Use log_softmax to ensure numerical stability.
|
||||||
|
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
return _sample(probs, logprobs, input_metadata)
|
return _sample(probs, logprobs, input_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||||
|
embedding_bias: Optional[torch.Tensor],
|
||||||
|
vocab_size: int) -> torch.Tensor:
|
||||||
|
# Get the logits for the next tokens.
|
||||||
|
logits = torch.matmul(hidden_states, embedding.t())
|
||||||
|
if embedding_bias is not None:
|
||||||
|
logits += embedding_bias
|
||||||
|
logits = gather_from_tensor_model_parallel_region(logits)
|
||||||
|
# Remove paddings in vocab (if any).
|
||||||
|
logits = logits[:, :vocab_size]
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _prune_hidden_states(
|
def _prune_hidden_states(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
last_token_indices = {t: [] for t in SamplingType}
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
last_token_indicies: List[int] = []
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
for prompt_len in input_metadata.prompt_lens:
|
seq_ids, sampling_params = seq_group
|
||||||
last_token_indicies.append(start_idx + prompt_len - 1)
|
sampling_type = sampling_params.sampling_type
|
||||||
start_idx += prompt_len
|
if i < input_metadata.num_prompts:
|
||||||
last_token_indicies.extend(
|
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
prompt_len = input_metadata.prompt_lens[i]
|
||||||
return hidden_states.index_select(
|
last_token_indices[sampling_type].append(start_idx + prompt_len -
|
||||||
0, torch.tensor(last_token_indicies, device=hidden_states.device))
|
1)
|
||||||
|
start_idx += prompt_len
|
||||||
|
else:
|
||||||
|
num_seqs = len(seq_ids)
|
||||||
|
last_token_indices[sampling_type].extend(
|
||||||
|
range(start_idx, start_idx + num_seqs))
|
||||||
|
start_idx += num_seqs
|
||||||
|
|
||||||
|
all_last_token_indices = []
|
||||||
|
for sampling_type in SamplingType:
|
||||||
|
all_last_token_indices.extend(last_token_indices[sampling_type])
|
||||||
|
all_last_token_indices = torch.tensor(all_last_token_indices,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=hidden_states.device)
|
||||||
|
return hidden_states.index_select(0, all_last_token_indices)
|
||||||
|
|
||||||
|
|
||||||
def _get_penalties(
|
def _get_penalties(
|
||||||
@@ -109,37 +133,22 @@ def _get_penalties(
|
|||||||
# Collect the presence and frequency penalties.
|
# Collect the presence and frequency penalties.
|
||||||
presence_penalties: List[float] = []
|
presence_penalties: List[float] = []
|
||||||
frequency_penalties: List[float] = []
|
frequency_penalties: List[float] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
p = sampling_params.presence_penalty
|
p = sampling_params.presence_penalty
|
||||||
f = sampling_params.frequency_penalty
|
f = sampling_params.frequency_penalty
|
||||||
if i < input_metadata.num_prompts:
|
presence_penalties += [p] * len(seq_ids)
|
||||||
# A prompt input.
|
frequency_penalties += [f] * len(seq_ids)
|
||||||
presence_penalties.append(p)
|
|
||||||
frequency_penalties.append(f)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
presence_penalties += [p] * len(seq_ids)
|
|
||||||
frequency_penalties += [f] * len(seq_ids)
|
|
||||||
return presence_penalties, frequency_penalties
|
return presence_penalties, frequency_penalties
|
||||||
|
|
||||||
|
|
||||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||||
output_tokens: List[List[int]] = []
|
output_tokens: List[List[int]] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, _ = seq_group
|
seq_ids, _ = seq_group
|
||||||
if i < input_metadata.num_prompts:
|
for seq_id in seq_ids:
|
||||||
# A prompt input.
|
|
||||||
# NOTE: While the prompt input usually has no output tokens,
|
|
||||||
# it may have output tokens in the case of recomputation.
|
|
||||||
seq_id = seq_ids[0]
|
|
||||||
seq_data = input_metadata.seq_data[seq_id]
|
seq_data = input_metadata.seq_data[seq_id]
|
||||||
output_tokens.append(seq_data.output_token_ids)
|
output_tokens.append(seq_data.output_token_ids)
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
for seq_id in seq_ids:
|
|
||||||
seq_data = input_metadata.seq_data[seq_id]
|
|
||||||
output_tokens.append(seq_data.output_token_ids)
|
|
||||||
return output_tokens
|
return output_tokens
|
||||||
|
|
||||||
|
|
||||||
@@ -148,11 +157,8 @@ def _apply_penalties(
|
|||||||
output_tokens: List[List[int]],
|
output_tokens: List[List[int]],
|
||||||
presence_penalties: List[float],
|
presence_penalties: List[float],
|
||||||
frequency_penalties: List[float],
|
frequency_penalties: List[float],
|
||||||
vocab_size: int,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_seqs = logits.shape[0]
|
num_seqs, vocab_size = logits.shape
|
||||||
# Collect the indices of sequences that have non-zero penalties.
|
|
||||||
indices = []
|
|
||||||
for i in range(num_seqs):
|
for i in range(num_seqs):
|
||||||
if not output_tokens[i]:
|
if not output_tokens[i]:
|
||||||
continue
|
continue
|
||||||
@@ -160,40 +166,47 @@ def _apply_penalties(
|
|||||||
f = frequency_penalties[i]
|
f = frequency_penalties[i]
|
||||||
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
||||||
continue
|
continue
|
||||||
indices.append(i)
|
break
|
||||||
|
else:
|
||||||
# Return early if all sequences have zero penalties.
|
# Return early if all sequences have zero penalties.
|
||||||
if not indices:
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
bin_counts = []
|
max_output_len = max(len(tokens) for tokens in output_tokens)
|
||||||
for i in indices:
|
padded_output_tokens = [
|
||||||
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
|
tokens + [vocab_size] * (max_output_len - len(tokens))
|
||||||
bin_counts = np.stack(bin_counts, axis=0)
|
for tokens in output_tokens
|
||||||
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
|
]
|
||||||
device=logits.device)
|
output_tokens_tensor = torch.tensor(padded_output_tokens,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=logits.device)
|
||||||
|
|
||||||
|
# Compute the bin counts for the output tokens.
|
||||||
|
# vocab_size + 1 for padding.
|
||||||
|
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=logits.device)
|
||||||
|
bin_counts.scatter_add_(1, output_tokens_tensor,
|
||||||
|
torch.ones_like(output_tokens_tensor))
|
||||||
|
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
|
||||||
|
|
||||||
frequency_penalties = [frequency_penalties[i] for i in indices]
|
|
||||||
frequency_penalties = torch.tensor(frequency_penalties,
|
frequency_penalties = torch.tensor(frequency_penalties,
|
||||||
dtype=logits.dtype,
|
dtype=logits.dtype,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
presence_penalties = [presence_penalties[i] for i in indices]
|
|
||||||
presence_penalties = torch.tensor(presence_penalties,
|
presence_penalties = torch.tensor(presence_penalties,
|
||||||
dtype=logits.dtype,
|
dtype=logits.dtype,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
# We follow the definition in OpenAI API.
|
# We follow the definition in OpenAI API.
|
||||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||||
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||||
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
|
logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
|
||||||
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||||
# Collect the temperatures for the logits.
|
# Collect the temperatures for the logits.
|
||||||
temperatures: List[float] = []
|
temperatures: List[float] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
temperature = sampling_params.temperature
|
temperature = sampling_params.temperature
|
||||||
if temperature < _SAMPLING_EPS:
|
if temperature < _SAMPLING_EPS:
|
||||||
@@ -201,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
|||||||
# (i.e., greedy sampling or beam search).
|
# (i.e., greedy sampling or beam search).
|
||||||
# Set the temperature to 1 to avoid division by zero.
|
# Set the temperature to 1 to avoid division by zero.
|
||||||
temperature = 1.0
|
temperature = 1.0
|
||||||
|
temperatures += [temperature] * len(seq_ids)
|
||||||
if i < input_metadata.num_prompts:
|
|
||||||
# A prompt input.
|
|
||||||
temperatures.append(temperature)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
temperatures += [temperature] * len(seq_ids)
|
|
||||||
return temperatures
|
return temperatures
|
||||||
|
|
||||||
|
|
||||||
@@ -217,21 +224,15 @@ def _get_top_p_top_k(
|
|||||||
) -> Tuple[List[float], List[int]]:
|
) -> Tuple[List[float], List[int]]:
|
||||||
top_ps: List[float] = []
|
top_ps: List[float] = []
|
||||||
top_ks: List[int] = []
|
top_ks: List[int] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
top_p = sampling_params.top_p
|
top_p = sampling_params.top_p
|
||||||
# k should not be greater than the vocab size.
|
# k should not be greater than the vocab size.
|
||||||
top_k = min(sampling_params.top_k, vocab_size)
|
top_k = min(sampling_params.top_k, vocab_size)
|
||||||
# k=-1 means no truncation.
|
# k=-1 means no truncation.
|
||||||
top_k = vocab_size if top_k == -1 else top_k
|
top_k = vocab_size if top_k == -1 else top_k
|
||||||
if i < input_metadata.num_prompts:
|
top_ps += [top_p] * len(seq_ids)
|
||||||
# A prompt input.
|
top_ks += [top_k] * len(seq_ids)
|
||||||
top_ps.append(top_p)
|
|
||||||
top_ks.append(top_k)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
top_ps += [top_p] * len(seq_ids)
|
|
||||||
top_ks += [top_k] * len(seq_ids)
|
|
||||||
return top_ps, top_ks
|
return top_ps, top_ks
|
||||||
|
|
||||||
|
|
||||||
@@ -267,95 +268,154 @@ def _apply_top_p_top_k(
|
|||||||
def _get_topk_logprobs(
|
def _get_topk_logprobs(
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
num_logprobs: Optional[int],
|
num_logprobs: Optional[int],
|
||||||
) -> Dict[int, float]:
|
) -> List[Dict[int, float]]:
|
||||||
|
num_seqs = logprobs.size(0)
|
||||||
if num_logprobs is None or num_logprobs == 0:
|
if num_logprobs is None or num_logprobs == 0:
|
||||||
return {}
|
return [{} for _ in range(num_seqs)]
|
||||||
|
|
||||||
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
|
all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
|
||||||
if num_logprobs == 1:
|
num_logprobs,
|
||||||
topk_logprobs = [topk_logprobs.item()]
|
dim=-1)
|
||||||
topk_ids = [topk_ids.item()]
|
all_topk_logprobs = all_topk_logprobs.cpu()
|
||||||
else:
|
all_topk_ids = all_topk_ids.cpu()
|
||||||
topk_logprobs = topk_logprobs.tolist()
|
all_token_to_logprob = []
|
||||||
topk_ids = topk_ids.tolist()
|
for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
|
||||||
|
token_to_logprob: Dict[int, float] = {}
|
||||||
token_to_logprob: Dict[int, float] = {}
|
for token_id, logprob in zip(topk_ids, topk_logprobs):
|
||||||
for token_id, logprob in zip(topk_ids, topk_logprobs):
|
token_to_logprob[token_id.item()] = logprob.item()
|
||||||
token_to_logprob[token_id] = logprob
|
all_token_to_logprob.append(token_to_logprob)
|
||||||
return token_to_logprob
|
return all_token_to_logprob
|
||||||
|
|
||||||
|
|
||||||
def _sample_from_prompt(
|
def _build_sequence_outputs(
|
||||||
prob: torch.Tensor,
|
parent_ids: List[int],
|
||||||
sampling_params: SamplingParams,
|
next_token_ids: List[int],
|
||||||
) -> List[int]:
|
selected_token_logprobs: torch.Tensor,
|
||||||
if sampling_params.use_beam_search:
|
parent_seq_ids: List[int],
|
||||||
# Beam search.
|
parent_logprobs: torch.Tensor,
|
||||||
beam_width = sampling_params.best_of
|
num_output_logprobs: Optional[int],
|
||||||
# Sample 2 * beam_width candidates to make sure that with high
|
) -> List[SequenceOutputs]:
|
||||||
# probability we can get `beam_width` candidates in addition to
|
# Get top-k log probabilities for the next tokens.
|
||||||
# the finished sequences for the next iteration. See
|
next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
|
||||||
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
seq_outputs: List[SequenceOutputs] = []
|
||||||
# for details. See also HF reference:
|
for parent_id, next_token_id, token_logprob in zip(
|
||||||
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
parent_ids, next_token_ids, selected_token_logprobs):
|
||||||
_, next_token_ids = torch.topk(prob, 2 * beam_width)
|
output_logprobs = next_logprobs[parent_id].copy()
|
||||||
next_token_ids = next_token_ids.tolist()
|
output_logprobs[next_token_id] = token_logprob
|
||||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
seq_outputs.append(
|
||||||
# Greedy sampling.
|
SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
|
||||||
assert sampling_params.best_of == 1
|
output_logprobs))
|
||||||
next_token_id = torch.argmax(prob)
|
return seq_outputs
|
||||||
next_token_ids = [next_token_id.item()]
|
|
||||||
else:
|
|
||||||
# Random sampling.
|
|
||||||
# Sample `best_of` tokens for the prompt.
|
|
||||||
num_seqs = sampling_params.best_of
|
|
||||||
next_token_ids = torch.multinomial(prob,
|
|
||||||
num_samples=num_seqs,
|
|
||||||
replacement=True)
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
|
||||||
return next_token_ids
|
|
||||||
|
|
||||||
|
|
||||||
def _sample_from_generation_tokens(
|
def _greedy_sample(
|
||||||
seq_ids: List[int],
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
probs: torch.Tensor,
|
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
seq_logprobs: List[float],
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
sampling_params: SamplingParams,
|
samples = torch.argmax(logprobs, dim=-1).cpu()
|
||||||
) -> Tuple[List[int], List[int]]:
|
sample_idx = 0
|
||||||
# NOTE(woosuk): sampling_params.best_of can be greater than
|
results = []
|
||||||
# len(seq_ids) because some sequences in the group might have
|
for seq_group in selected_seq_groups:
|
||||||
# been already terminated.
|
seq_ids, _ = seq_group
|
||||||
if sampling_params.use_beam_search:
|
num_parent_seqs = len(seq_ids)
|
||||||
# Beam search.
|
assert num_parent_seqs == 1, (
|
||||||
# Add cumulative logprobs for the sequences in the group.
|
"Greedy sampling should have only one seq.")
|
||||||
seq_logprobs = torch.tensor(seq_logprobs,
|
parent_ids = list(range(num_parent_seqs))
|
||||||
dtype=torch.float,
|
next_token_ids = [samples[sample_idx].item()]
|
||||||
device=logprobs.device)
|
results.append((next_token_ids, parent_ids))
|
||||||
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == logprobs.size(0)
|
||||||
|
return results
|
||||||
|
|
||||||
vocab_size = logprobs.size(-1)
|
|
||||||
beam_width = len(seq_ids)
|
def _random_sample(
|
||||||
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
topk_ids = topk_ids.tolist()
|
is_prompts: List[bool],
|
||||||
seq_idx = [i // vocab_size for i in topk_ids]
|
probs: torch.Tensor,
|
||||||
parent_seq_ids = [seq_ids[i] for i in seq_idx]
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
next_token_ids = [i % vocab_size for i in topk_ids]
|
# Find the maximum best_of value of the prompt phase requests.
|
||||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
max_best_of = 1
|
||||||
# Greedy sampling.
|
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||||
assert len(seq_ids) == 1
|
if is_prompt:
|
||||||
next_token_id = torch.argmax(probs, dim=-1)
|
seq_ids, sampling_params = seq_group
|
||||||
next_token_ids = [int(next_token_id.item())]
|
max_best_of = max(max_best_of, sampling_params.best_of)
|
||||||
parent_seq_ids = seq_ids
|
random_samples = torch.multinomial(probs,
|
||||||
else:
|
num_samples=max_best_of,
|
||||||
# Random sampling.
|
replacement=True).cpu()
|
||||||
# Sample 1 token for each sequence in the group.
|
sample_idx = 0
|
||||||
next_token_ids = torch.multinomial(probs,
|
results = []
|
||||||
num_samples=1,
|
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||||
replacement=True)
|
seq_ids, sampling_params = seq_group
|
||||||
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
|
num_parent_seqs = len(seq_ids)
|
||||||
parent_seq_ids = seq_ids
|
if is_prompt:
|
||||||
return parent_seq_ids, next_token_ids
|
# Prompt phase.
|
||||||
|
assert num_parent_seqs == 1, (
|
||||||
|
"Prompt input should have only one seq.")
|
||||||
|
parent_ids = [0] * sampling_params.best_of
|
||||||
|
next_token_ids = random_samples[
|
||||||
|
sample_idx, :sampling_params.best_of].tolist()
|
||||||
|
else:
|
||||||
|
# Generation phase.
|
||||||
|
parent_ids = list(range(num_parent_seqs))
|
||||||
|
next_token_ids = random_samples[sample_idx:sample_idx +
|
||||||
|
num_parent_seqs, 0].tolist()
|
||||||
|
results.append((next_token_ids, parent_ids))
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == probs.size(0)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _beam_search_sample(
|
||||||
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
|
is_prompts: List[bool],
|
||||||
|
seq_data: Dict[int, SequenceData],
|
||||||
|
logprobs: torch.Tensor,
|
||||||
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
|
# We sample 2 * beam_width candidates to make sure that with high
|
||||||
|
# probability we can get `beam_width` candidates in addition to
|
||||||
|
# the finished sequences for the next iteration. See
|
||||||
|
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||||
|
# for details. See also HF reference:
|
||||||
|
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||||
|
#
|
||||||
|
# Note: Beam search is not vectorized, so its speed can be slower than
|
||||||
|
# other sampling methods.
|
||||||
|
sample_idx = 0
|
||||||
|
results = []
|
||||||
|
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||||
|
seq_ids, sampling_params = seq_group
|
||||||
|
num_parent_seqs = len(seq_ids)
|
||||||
|
beam_width = sampling_params.best_of
|
||||||
|
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
||||||
|
if is_prompt:
|
||||||
|
# Prompt phase.
|
||||||
|
assert num_parent_seqs == 1, (
|
||||||
|
"Prompt input should have only one seq.")
|
||||||
|
parent_ids = [0] * (2 * beam_width)
|
||||||
|
_, next_token_ids = torch.topk(seq_group_logprobs[0],
|
||||||
|
2 * beam_width)
|
||||||
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
else:
|
||||||
|
# Generation phase.
|
||||||
|
cumulative_logprobs = [
|
||||||
|
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
|
||||||
|
]
|
||||||
|
cumulative_logprobs = torch.tensor(
|
||||||
|
cumulative_logprobs,
|
||||||
|
dtype=torch.float,
|
||||||
|
device=seq_group_logprobs.device)
|
||||||
|
seq_group_logprobs = (seq_group_logprobs +
|
||||||
|
cumulative_logprobs.unsqueeze(dim=1))
|
||||||
|
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
|
||||||
|
2 * beam_width)
|
||||||
|
topk_ids = topk_ids.tolist()
|
||||||
|
vocab_size = seq_group_logprobs.size(-1)
|
||||||
|
parent_ids = [i // vocab_size for i in topk_ids]
|
||||||
|
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||||
|
results.append((next_token_ids, parent_ids))
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == logprobs.size(0)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _sample(
|
def _sample(
|
||||||
@@ -363,65 +423,80 @@ def _sample(
|
|||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> SamplerOutput:
|
) -> SamplerOutput:
|
||||||
seq_outputs: SamplerOutput = []
|
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||||
|
category_num_tokens = {t: 0 for t in SamplingType}
|
||||||
# TODO(woosuk): Optimize.
|
|
||||||
idx = 0
|
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
seq_group_outputs: List[SequenceOutputs] = []
|
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
if i < input_metadata.num_prompts:
|
sampling_type = sampling_params.sampling_type
|
||||||
# Generate the next tokens for a prompt input.
|
categorized_seq_group_ids[sampling_type].append(i)
|
||||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
num_seqs = len(seq_ids)
|
||||||
parent_seq_id = seq_ids[0]
|
category_num_tokens[sampling_type] += num_seqs
|
||||||
prob = probs[idx]
|
|
||||||
logprob = logprobs[idx]
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
# Sample the next tokens.
|
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
|
||||||
next_token_ids = _sample_from_prompt(prob, sampling_params)
|
category_start_idx = 0
|
||||||
# Get top-k log probabilities for the next tokens.
|
for sampling_type in SamplingType:
|
||||||
next_logprobs = _get_topk_logprobs(logprob,
|
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||||
sampling_params.logprobs)
|
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
|
||||||
|
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
|
||||||
# Build the output.
|
num_tokens = category_num_tokens[sampling_type]
|
||||||
for next_token_id in next_token_ids:
|
if num_tokens == 0:
|
||||||
output_logprobs = next_logprobs.copy()
|
continue
|
||||||
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
category_logprobs = logprobs[category_start_idx:category_start_idx +
|
||||||
seq_group_outputs.append(
|
num_tokens]
|
||||||
SequenceOutputs(parent_seq_id, next_token_id,
|
category_probs = probs[category_start_idx:category_start_idx +
|
||||||
output_logprobs))
|
num_tokens]
|
||||||
|
if sampling_type == SamplingType.GREEDY:
|
||||||
|
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
||||||
|
elif sampling_type == SamplingType.RANDOM:
|
||||||
|
sample_results = _random_sample(seq_groups, is_prompts,
|
||||||
|
category_probs)
|
||||||
|
elif sampling_type == SamplingType.BEAM:
|
||||||
|
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
||||||
|
input_metadata.seq_data,
|
||||||
|
category_logprobs)
|
||||||
else:
|
else:
|
||||||
# Generate the next tokens for generation tokens.
|
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
||||||
|
|
||||||
|
# Batched query for logprobs of selected token
|
||||||
|
batched_logprobs_query_seq_indices: List[int] = []
|
||||||
|
batched_logprobs_query_token_indices: List[int] = []
|
||||||
|
sample_idx = 0
|
||||||
|
for seq_group_id, seq_group, sample_result in zip(
|
||||||
|
seq_group_ids, seq_groups, sample_results):
|
||||||
|
seq_ids, sampling_params = seq_group
|
||||||
|
next_token_ids, parent_ids = sample_result
|
||||||
num_parent_seqs = len(seq_ids)
|
num_parent_seqs = len(seq_ids)
|
||||||
prob = probs[idx:idx + num_parent_seqs]
|
batched_logprobs_query_seq_indices.extend(
|
||||||
logprob = logprobs[idx:idx + num_parent_seqs]
|
[sample_idx + parent_id for parent_id in parent_ids])
|
||||||
idx += num_parent_seqs
|
batched_logprobs_query_token_indices.extend(next_token_ids)
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == num_tokens
|
||||||
|
batched_logprobs_query_result = category_logprobs[[
|
||||||
|
batched_logprobs_query_seq_indices,
|
||||||
|
batched_logprobs_query_token_indices
|
||||||
|
]].tolist()
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Build the sequence outputs.
|
||||||
seq_logprobs = [
|
sample_idx = 0
|
||||||
input_metadata.seq_data[seq_id].cumulative_logprob
|
result_idx = 0
|
||||||
for seq_id in seq_ids
|
for seq_group_id, seq_group, sample_result in zip(
|
||||||
]
|
seq_group_ids, seq_groups, sample_results):
|
||||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
seq_ids, sampling_params = seq_group
|
||||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
next_token_ids, parent_ids = sample_result
|
||||||
|
num_results = len(next_token_ids)
|
||||||
|
num_parent_seqs = len(seq_ids)
|
||||||
|
parent_logprobs = category_logprobs[sample_idx:sample_idx +
|
||||||
|
num_parent_seqs]
|
||||||
|
selected_token_logprobs = batched_logprobs_query_result[
|
||||||
|
result_idx:result_idx + num_results]
|
||||||
|
seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
|
||||||
|
selected_token_logprobs,
|
||||||
|
seq_ids, parent_logprobs,
|
||||||
|
sampling_params.logprobs)
|
||||||
|
seq_outputs_dict[seq_group_id] = seq_output
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
result_idx += num_results
|
||||||
|
assert sample_idx == num_tokens
|
||||||
|
category_start_idx += num_tokens
|
||||||
|
|
||||||
# Get top-k log probabilities for the next tokens.
|
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
||||||
next_logprobs: Dict[int, Dict[int, float]] = {}
|
|
||||||
for j, seq_id in enumerate(seq_ids):
|
|
||||||
next_logprobs[seq_id] = _get_topk_logprobs(
|
|
||||||
logprob[j], sampling_params.logprobs)
|
|
||||||
|
|
||||||
# Build the output.
|
|
||||||
for parent_seq_id, next_token_id in zip(parent_seq_ids,
|
|
||||||
next_token_ids):
|
|
||||||
j = seq_ids.index(parent_seq_id)
|
|
||||||
output_logprobs = next_logprobs[parent_seq_id].copy()
|
|
||||||
output_logprobs[next_token_id] = logprob[j,
|
|
||||||
next_token_id].item()
|
|
||||||
seq_group_outputs.append(
|
|
||||||
SequenceOutputs(parent_seq_id, next_token_id,
|
|
||||||
output_logprobs))
|
|
||||||
seq_outputs.append(seq_group_outputs)
|
|
||||||
|
|
||||||
return seq_outputs
|
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
||||||
from vllm.model_executor.weight_utils import initialize_dummy_weights
|
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||||
|
initialize_dummy_weights)
|
||||||
|
|
||||||
# TODO(woosuk): Lazy-load the model classes.
|
# TODO(woosuk): Lazy-load the model classes.
|
||||||
_MODEL_REGISTRY = {
|
_MODEL_REGISTRY = {
|
||||||
@@ -24,12 +25,18 @@ _MODEL_REGISTRY = {
|
|||||||
"InternLMForCausalLM": InternLMForCausalLM,
|
"InternLMForCausalLM": InternLMForCausalLM,
|
||||||
"LlamaForCausalLM": LlamaForCausalLM,
|
"LlamaForCausalLM": LlamaForCausalLM,
|
||||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||||
|
"MistralForCausalLM": MistralForCausalLM,
|
||||||
"MPTForCausalLM": MPTForCausalLM,
|
"MPTForCausalLM": MPTForCausalLM,
|
||||||
"OPTForCausalLM": OPTForCausalLM,
|
"OPTForCausalLM": OPTForCausalLM,
|
||||||
"QWenLMHeadModel": QWenLMHeadModel,
|
"QWenLMHeadModel": QWenLMHeadModel,
|
||||||
"RWForCausalLM": FalconForCausalLM,
|
"RWForCausalLM": FalconForCausalLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# FIXME(woosuk): Remove this once all models support quantization.
|
||||||
|
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
|
||||||
|
LlamaForCausalLM,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||||
@@ -52,10 +59,38 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
|||||||
|
|
||||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||||
model_class = _get_model_architecture(model_config.hf_config)
|
model_class = _get_model_architecture(model_config.hf_config)
|
||||||
|
|
||||||
|
# Get the quantization config.
|
||||||
|
quant_config = None
|
||||||
|
if model_config.quantization is not None:
|
||||||
|
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||||
|
raise ValueError(
|
||||||
|
f"Quantization is not supported for {model_class}.")
|
||||||
|
quant_config = get_quant_config(model_config.quantization,
|
||||||
|
model_config.model,
|
||||||
|
model_config.download_dir)
|
||||||
|
capability = torch.cuda.get_device_capability()
|
||||||
|
capability = capability[0] * 10 + capability[1]
|
||||||
|
if capability < quant_config.get_min_capability():
|
||||||
|
raise ValueError(
|
||||||
|
f"The quantization method {model_config.quantization} is not "
|
||||||
|
"supported for the current GPU. "
|
||||||
|
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||||
|
f"Current capability: {capability}.")
|
||||||
|
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||||
|
if model_config.dtype not in supported_dtypes:
|
||||||
|
raise ValueError(
|
||||||
|
f"{model_config.dtype} is not supported for quantization "
|
||||||
|
f"method {model_config.quantization}. Supported dtypes: "
|
||||||
|
f"{supported_dtypes}")
|
||||||
|
|
||||||
with _set_default_torch_dtype(model_config.dtype):
|
with _set_default_torch_dtype(model_config.dtype):
|
||||||
# Create a model instance.
|
# Create a model instance.
|
||||||
# The weights will be initialized as empty tensors.
|
# The weights will be initialized as empty tensors.
|
||||||
model = model_class(model_config.hf_config)
|
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||||
|
model = model_class(model_config.hf_config, quant_config)
|
||||||
|
else:
|
||||||
|
model = model_class(model_config.hf_config)
|
||||||
if model_config.load_format == "dummy":
|
if model_config.load_format == "dummy":
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
@@ -64,6 +99,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
|||||||
else:
|
else:
|
||||||
# Load the weights from the cached or downloaded files.
|
# Load the weights from the cached or downloaded files.
|
||||||
model.load_weights(model_config.model, model_config.download_dir,
|
model.load_weights(model_config.model, model_config.download_dir,
|
||||||
model_config.load_format)
|
model_config.load_format, model_config.revision)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
|
|||||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||||
|
from vllm.model_executor.models.mistral import MistralForCausalLM
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AquilaForCausalLM",
|
"AquilaForCausalLM",
|
||||||
@@ -28,4 +29,5 @@ __all__ = [
|
|||||||
"MPTForCausalLM",
|
"MPTForCausalLM",
|
||||||
"OPTForCausalLM",
|
"OPTForCausalLM",
|
||||||
"QWenLMHeadModel",
|
"QWenLMHeadModel",
|
||||||
|
"MistralForCausalLM",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -105,6 +105,8 @@ class AquilaAttention(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -119,6 +121,8 @@ class AquilaAttention(nn.Module):
|
|||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
self.qkv_proj = ColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -140,6 +144,8 @@ class AquilaAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=self.max_position_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -164,10 +170,15 @@ class AquilaDecoderLayer(nn.Module):
|
|||||||
def __init__(self, config: AquilaConfig):
|
def __init__(self, config: AquilaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
self.self_attn = AquilaAttention(
|
self.self_attn = AquilaAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
num_kv_heads=config.num_attention_heads,
|
num_kv_heads=config.num_attention_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
)
|
)
|
||||||
self.mlp = AquilaMLP(
|
self.mlp = AquilaMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@@ -288,7 +299,8 @@ class AquilaForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||||
@@ -305,7 +317,7 @@ class AquilaForCausalLM(nn.Module):
|
|||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -111,6 +111,8 @@ class BaiChuanAttention(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
position_embedding: str,
|
position_embedding: str,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -122,6 +124,8 @@ class BaiChuanAttention(nn.Module):
|
|||||||
tensor_model_parallel_world_size)
|
tensor_model_parallel_world_size)
|
||||||
self.head_dim = hidden_size // self.total_num_heads
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
self.postion_embedding = position_embedding
|
self.postion_embedding = position_embedding
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
self.W_pack = ColumnParallelLinear(
|
self.W_pack = ColumnParallelLinear(
|
||||||
@@ -151,10 +155,13 @@ class BaiChuanAttention(nn.Module):
|
|||||||
scaling, alibi_slopes)
|
scaling, alibi_slopes)
|
||||||
else:
|
else:
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
self.attn = PagedAttentionWithRoPE(
|
||||||
self.head_dim,
|
self.num_heads,
|
||||||
self.scaling,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim)
|
self.scaling,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=self.max_position_embeddings)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -183,10 +190,15 @@ class BaiChuanDecoderLayer(nn.Module):
|
|||||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
self.self_attn = BaiChuanAttention(
|
self.self_attn = BaiChuanAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
position_embedding=position_embedding,
|
position_embedding=position_embedding,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
)
|
)
|
||||||
self.mlp = BaiChuanMLP(
|
self.mlp = BaiChuanMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@@ -303,13 +315,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if name == "lm_head.weight":
|
if name == "lm_head.weight":
|
||||||
# Since hidden_states are parallelized, we need to
|
# Since hidden_states are parallelized, we need to
|
||||||
# load lm_head.weight in parallel.
|
# load lm_head.weight in parallel.
|
||||||
|
|||||||
@@ -161,12 +161,17 @@ class FalconAttention(nn.Module):
|
|||||||
"Rotary and alibi are mutually exclusive.")
|
"Rotary and alibi are mutually exclusive.")
|
||||||
|
|
||||||
if self.use_rotary:
|
if self.use_rotary:
|
||||||
# TODO(zhuohan): Pass in correct `max_position``
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
max_position_embeddings = getattr(config,
|
||||||
self.head_dim,
|
"max_position_embeddings", 8192)
|
||||||
self.inv_norm_factor,
|
self.attn = PagedAttentionWithRoPE(
|
||||||
rotary_dim=self.head_dim,
|
self.num_heads,
|
||||||
num_kv_heads=self.num_kv_heads)
|
self.head_dim,
|
||||||
|
self.inv_norm_factor,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
elif self.use_alibi:
|
elif self.use_alibi:
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
head_start = tp_rank * self.num_heads
|
head_start = tp_rank * self.num_heads
|
||||||
@@ -420,7 +425,8 @@ class FalconForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_size = (get_tensor_model_parallel_world_size())
|
tp_size = (get_tensor_model_parallel_world_size())
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
@@ -452,7 +458,7 @@ class FalconForCausalLM(nn.Module):
|
|||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "query_key_value" in name:
|
if "query_key_value" in name:
|
||||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
loaded_weight_size = loaded_weight.size()
|
loaded_weight_size = loaded_weight.size()
|
||||||
|
|||||||
@@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_world_size = (
|
tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
# GPT-2 ties the weights of the embedding layer and the final
|
# GPT-2 ties the weights of the embedding layer and the final
|
||||||
# linear layer.
|
# linear layer.
|
||||||
|
|||||||
@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_world_size = (
|
tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
# GPT-2 ties the weights of the embedding layer and the final
|
# GPT-2 ties the weights of the embedding layer and the final
|
||||||
# linear layer.
|
# linear layer.
|
||||||
|
|||||||
@@ -67,11 +67,17 @@ class GPTJAttention(nn.Module):
|
|||||||
scaling = self.head_size**-0.5
|
scaling = self.head_size**-0.5
|
||||||
assert getattr(config, "rotary", True)
|
assert getattr(config, "rotary", True)
|
||||||
assert config.rotary_dim % 2 == 0
|
assert config.rotary_dim % 2 == 0
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
self.head_size,
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
scaling,
|
8192)
|
||||||
config.rotary_dim,
|
self.attn = PagedAttentionWithRoPE(
|
||||||
is_neox_style=False)
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
config.rotary_dim,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
is_neox_style=False)
|
||||||
self.warmup = False
|
self.warmup = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -222,11 +228,12 @@ class GPTJForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -68,8 +68,16 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
scaling = self.head_size**-0.5
|
scaling = self.head_size**-0.5
|
||||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||||
assert rotary_dim % 2 == 0
|
assert rotary_dim % 2 == 0
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
scaling, rotary_dim)
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
rotary_dim,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -231,11 +239,12 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||||
or "rotary_emb.inv_freq" in name):
|
or "rotary_emb.inv_freq" in name):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ class InternLMAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -70,6 +72,8 @@ class InternLMAttention(nn.Module):
|
|||||||
tensor_model_parallel_world_size)
|
tensor_model_parallel_world_size)
|
||||||
self.head_dim = hidden_size // self.total_num_heads
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
self.qkv_proj = ColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -85,10 +89,13 @@ class InternLMAttention(nn.Module):
|
|||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
perform_initialization=False,
|
perform_initialization=False,
|
||||||
)
|
)
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
self.attn = PagedAttentionWithRoPE(
|
||||||
self.head_dim,
|
self.num_heads,
|
||||||
self.scaling,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim)
|
self.scaling,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=self.max_position_embeddings,
|
||||||
|
rotary_dim=self.head_dim)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -112,9 +119,14 @@ class InternLMDecoderLayer(nn.Module):
|
|||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
self.self_attn = InternLMAttention(
|
self.self_attn = InternLMAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
)
|
)
|
||||||
self.mlp = InternLMMLP(
|
self.mlp = InternLMMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@@ -233,12 +245,13 @@ class InternLMForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
|||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (
|
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
|
|
||||||
hf_model_weights_iterator)
|
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||||
from vllm.sequence import SamplerOutput
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
@@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||||
2 * intermediate_size,
|
2 * intermediate_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
gather_output=False,
|
gather_output=False,
|
||||||
perform_initialization=False)
|
perform_initialization=False,
|
||||||
self.down_proj = RowParallelLinear(intermediate_size,
|
quant_config=quant_config)
|
||||||
hidden_size,
|
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||||
bias=False,
|
hidden_size,
|
||||||
input_is_parallel=True,
|
bias=False,
|
||||||
perform_initialization=False)
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
"Only silu is supported for now.")
|
"Only silu is supported for now.")
|
||||||
@@ -87,7 +92,10 @@ class LlamaAttention(nn.Module):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
):
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@@ -102,28 +110,34 @@ class LlamaAttention(nn.Module):
|
|||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
self.qkv_proj = ParallelLinear.column(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
gather_output=False,
|
gather_output=False,
|
||||||
perform_initialization=False,
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = ParallelLinear.row(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
perform_initialization=False,
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
self.attn = PagedAttentionWithRoPE(
|
||||||
self.head_dim,
|
self.num_heads,
|
||||||
self.scaling,
|
self.head_dim,
|
||||||
base=self.rope_theta,
|
self.scaling,
|
||||||
rotary_dim=self.head_dim,
|
base=self.rope_theta,
|
||||||
num_kv_heads=self.num_kv_heads)
|
max_position=self.max_position_embeddings,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
rope_scaling=rope_scaling)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -144,21 +158,32 @@ class LlamaAttention(nn.Module):
|
|||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
class LlamaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
# Requires transformers > 4.32.0
|
# Requires transformers > 4.32.0
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
self.self_attn = LlamaAttention(
|
self.self_attn = LlamaAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(
|
self.mlp = LlamaMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
@@ -195,7 +220,11 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
class LlamaModel(nn.Module):
|
class LlamaModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
@@ -205,7 +234,8 @@ class LlamaModel(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
vocab_size, config.hidden_size, perform_initialization=False)
|
vocab_size, config.hidden_size, perform_initialization=False)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
LlamaDecoderLayer(config, quant_config)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
@@ -237,16 +267,23 @@ class LlamaModel(nn.Module):
|
|||||||
|
|
||||||
class LlamaForCausalLM(nn.Module):
|
class LlamaForCausalLM(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = LlamaModel(config)
|
self.quant_config = quant_config
|
||||||
|
self.model = LlamaModel(config, quant_config)
|
||||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
# NOTE: The LM head is not quantized.
|
||||||
vocab_size,
|
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||||
bias=False,
|
vocab_size,
|
||||||
gather_output=False,
|
bias=False,
|
||||||
perform_initialization=False)
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=None)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -263,15 +300,28 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = [
|
_column_parallel_layers = []
|
||||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||||
]
|
|
||||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
if self.quant_config is None:
|
||||||
|
weight_suffixes = ["weight"]
|
||||||
|
else:
|
||||||
|
weight_suffixes = self.quant_config.get_tp_tensor_names()
|
||||||
|
|
||||||
|
column_parallel_weights: List[str] = []
|
||||||
|
for layer in self._column_parallel_layers:
|
||||||
|
for suffix in weight_suffixes:
|
||||||
|
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||||
|
row_parallel_weights: List[str] = []
|
||||||
|
for layer in self._row_parallel_layers:
|
||||||
|
for suffix in weight_suffixes:
|
||||||
|
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||||
@@ -288,15 +338,30 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
is_packed = False
|
||||||
|
is_transposed = False
|
||||||
|
if self.quant_config is not None:
|
||||||
|
is_packed = self.quant_config.is_packed(name)
|
||||||
|
is_transposed = self.quant_config.is_transposed(name)
|
||||||
|
if is_transposed:
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
loaded_weight = loaded_weight.T
|
||||||
|
|
||||||
is_attention_weight = False
|
is_attention_weight = False
|
||||||
for weight_name, shard_size, offset in attention_weight_specs:
|
for weight_name, shard_size, offset in attention_weight_specs:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
|
if is_packed:
|
||||||
|
shard_size //= self.quant_config.pack_factor
|
||||||
|
offset //= self.quant_config.pack_factor
|
||||||
|
|
||||||
loaded_weight = loaded_weight[
|
loaded_weight = loaded_weight[
|
||||||
shard_size * tensor_model_parallel_rank:shard_size *
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
@@ -315,6 +380,9 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
shard_size = param.shape[0] // 2
|
shard_size = param.shape[0] // 2
|
||||||
loaded_weight = loaded_weight[
|
loaded_weight = loaded_weight[
|
||||||
shard_size * tensor_model_parallel_rank:shard_size *
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
@@ -329,6 +397,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
if "embed_tokens" in name or "lm_head" in name:
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
@@ -336,6 +406,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
self._column_parallel_weights,
|
column_parallel_weights,
|
||||||
self._row_parallel_weights,
|
row_parallel_weights,
|
||||||
tensor_model_parallel_rank)
|
tensor_model_parallel_rank)
|
||||||
|
|||||||
404
vllm/model_executor/models/mistral.py
Normal file
404
vllm/model_executor/models/mistral.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only Mistral model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs.mistral import MistralConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class MistralMLP(nn.Module):
    """Feed-forward block: gate/up projection, SiluAndMul, down projection.

    The gate and up projections are held in one column-parallel layer of
    output width ``2 * intermediate_size``; its output is passed through
    ``SiluAndMul`` and then through a row-parallel down projection back to
    ``hidden_size``.

    Args:
        hidden_size: Input width of ``gate_up_proj`` and output width of
            ``down_proj``.
        intermediate_size: Input width of ``down_proj``; ``gate_up_proj``
            produces twice this width.
        hidden_act: Activation name. Only ``"silu"`` is accepted.
        quant_config: Forwarded unchanged to both projection layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Validate first so an unsupported activation is rejected before
        # any projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = ParallelLinear.column(hidden_size,
                                                  2 * intermediate_size,
                                                  bias=False,
                                                  gather_output=False,
                                                  perform_initialization=False,
                                                  quant_config=quant_config)
        self.down_proj = ParallelLinear.row(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            input_is_parallel=True,
                                            perform_initialization=False,
                                            quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Each projection returns a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class MistralAttention(nn.Module):
    """Self-attention block with a single QKV projection and an output projection.

    Query and key/value head counts are divided evenly across the
    tensor-parallel world size; construction asserts that both counts are
    divisible by it.

    Args:
        hidden_size: Model width; also the output width of ``o_proj``.
        num_heads: Total number of query heads (before sharding).
        num_kv_heads: Total number of key/value heads (before sharding).
        max_position: Forwarded to ``PagedAttentionWithRoPE``.
        rope_theta: Forwarded to ``PagedAttentionWithRoPE`` as ``base``.
        quant_config: Forwarded to both projection layers.
        sliding_window: Forwarded to ``PagedAttentionWithRoPE``.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 quant_config: Optional[QuantizationConfig] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        # Heads must split evenly over the tensor-parallel ranks.
        assert num_heads % world_size == 0
        assert num_kv_heads % world_size == 0

        head_dim = hidden_size // num_heads
        local_heads = num_heads // world_size
        local_kv_heads = num_kv_heads // world_size

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        self.num_heads = local_heads
        self.total_num_kv_heads = num_kv_heads
        self.num_kv_heads = local_kv_heads
        self.head_dim = head_dim
        # Per-rank widths of the q and k/v slices of the QKV output.
        self.q_size = local_heads * head_dim
        self.kv_size = local_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        qkv_width = (num_heads + 2 * num_kv_heads) * head_dim
        self.qkv_proj = ParallelLinear.column(
            hidden_size,
            qkv_width,
            bias=False,
            gather_output=False,
            perform_initialization=False,
            quant_config=quant_config,
        )
        self.o_proj = ParallelLinear.row(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
            quant_config=quant_config,
        )
        self.attn = PagedAttentionWithRoPE(local_heads,
                                           head_dim,
                                           self.scaling,
                                           base=rope_theta,
                                           max_position=max_position,
                                           rotary_dim=head_dim,
                                           num_kv_heads=local_kv_heads,
                                           sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to q/k/v, run the attention op, and project the result."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the last dimension into the q, k and v slices.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = qkv.split(split_sizes, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(positions, query, key, value, key_cache,
                                value_cache, input_metadata, cache_event)
        projected, _ = self.o_proj(attn_output)
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class MistralDecoderLayer(nn.Module):
    """One decoder layer: normed self-attention and normed MLP, each with a residual add.

    Args:
        config: Model configuration; supplies the sizes, head counts,
            ``rms_norm_eps``, ``sliding_window`` and (optionally)
            ``rope_theta``.
        quant_config: Forwarded to the attention and MLP sub-modules.
    """

    def __init__(
        self,
        config: MistralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        # Requires transformers > 4.32.0; fall back to 10000 when the
        # config has no rope_theta attribute.
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MistralAttention(
            hidden_size=hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            quant_config=quant_config,
            sliding_window=config.sliding_window)
        self.mlp = MistralMLP(
            hidden_size=hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention then the MLP, adding the block input back after each."""
        # Self-attention sub-block.
        attn_residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = attn_residual + attn_out

        # Feed-forward sub-block.
        mlp_residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        mlp_out = self.mlp(normed)
        return mlp_residual + mlp_out
|
||||||
|
|
||||||
|
|
||||||
|
class MistralModel(nn.Module):
    """Token embedding, a stack of ``MistralDecoderLayer`` and a final RMSNorm.

    Args:
        config: Model configuration; supplies vocabulary size, hidden size,
            layer count, ``pad_token_id`` and ``rms_norm_eps``.
        quant_config: Forwarded to every decoder layer.
    """

    def __init__(
        self,
        config: MistralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # The embedding table is sized to the vocabulary rounded up to a
        # multiple of 64.
        padded_vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            padded_vocab_size,
            config.hidden_size,
            perform_initialization=False)
        decoder_layers = [
            MistralDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer in order, then normalize.

        Layer ``i`` receives ``kv_caches[i]`` and, when ``cache_events`` is
        not None, ``cache_events[i]`` (otherwise None).
        """
        hidden_states = self.embed_tokens(input_ids)
        for idx, layer in enumerate(self.layers):
            event = None if cache_events is None else cache_events[idx]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[idx],
                input_metadata,
                event,
            )
        return self.norm(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralForCausalLM(nn.Module):
    """``MistralModel`` plus an unquantized LM head and a sampler.

    Args:
        config: Model configuration.
        quant_config: Forwarded to ``MistralModel``; the LM head is always
            built with ``quant_config=None``.
    """

    # Layer-name fragments combined with the weight suffixes in
    # ``load_weights`` and handed to ``load_tensor_parallel_weights``.
    _column_parallel_layers = []
    _row_parallel_layers = ["o_proj", "down_proj"]

    def __init__(
        self,
        config: MistralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = MistralModel(config, quant_config)
        # The head's output width is the vocabulary size rounded up to a
        # multiple of 64.
        padded_vocab_size = ((config.vocab_size + 63) // 64) * 64
        # NOTE: The LM head is not quantized.
        self.lm_head = ParallelLinear.column(config.hidden_size,
                                             padded_vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False,
                                             quant_config=None)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the model and return the sampler's output.

        The sampler is called with the LM head weight, the final hidden
        states and ``input_metadata``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head.weight, hidden_states,
                            input_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy checkpoint weights into this module's parameters.

        Weights come from ``hf_model_weights_iterator`` and are written into
        the tensors of ``self.state_dict()``:

        * names containing ``rotary_emb.inv_freq`` are skipped;
        * ``q_proj`` / ``k_proj`` / ``v_proj`` weights are sliced for this
          tensor-parallel rank and copied into their region of ``qkv_proj``;
        * ``gate_proj`` / ``up_proj`` weights are sliced and copied into the
          first / second half of ``gate_up_proj``;
        * ``embed_tokens`` and ``lm_head`` go through
          ``load_padded_tensor_parallel_vocab``;
        * everything else goes through ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.
        """
        quant_config = self.quant_config
        if quant_config is None:
            weight_suffixes = ["weight"]
        else:
            weight_suffixes = quant_config.get_tp_tensor_names()

        # Cross every parallel layer name with every weight suffix.
        column_parallel_weights: List[str] = [
            f"{layer}.{suffix}" for layer in self._column_parallel_layers
            for suffix in weight_suffixes
        ]
        row_parallel_weights: List[str] = [
            f"{layer}.{suffix}" for layer in self._row_parallel_layers
            for suffix in weight_suffixes
        ]

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        cfg = self.config
        q_shard = cfg.hidden_size // tp_size
        head_size = cfg.hidden_size // cfg.num_attention_heads
        kv_shard = head_size * cfg.num_key_value_heads // tp_size
        # (weight_name, shard_size, offset)
        attention_weight_specs = [
            ("q_proj", q_shard, 0),
            ("k_proj", kv_shard, q_shard),
            ("v_proj", kv_shard, q_shard + kv_shard),
        ]
        gate_up_specs = list(enumerate(["gate_proj", "up_proj"]))
        state_dict = self.state_dict()

        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if quant_config is None:
                is_packed = False
                is_transposed = False
            else:
                is_packed = quant_config.is_packed(name)
                is_transposed = quant_config.is_transposed(name)
            if is_transposed:
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                loaded_weight = loaded_weight.T

            # q/k/v projections: copy into the matching region of qkv_proj.
            attn_spec = next((spec for spec in attention_weight_specs
                              if spec[0] in name), None)
            if attn_spec is not None:
                weight_name, shard_size, offset = attn_spec
                param = state_dict[name.replace(weight_name, "qkv_proj")]
                if is_transposed:
                    param = param.T

                if is_packed:
                    # Shard size and offset are divided by the pack factor.
                    shard_size //= quant_config.pack_factor
                    offset //= quant_config.pack_factor

                start = shard_size * tp_rank
                stop = shard_size * (tp_rank + 1)
                loaded_weight = loaded_weight[start:stop]
                param_slice = param.data[offset:offset + shard_size]
                assert param_slice.shape == loaded_weight.shape

                param_slice.copy_(loaded_weight)
                continue

            # gate/up projections: copy into one half of gate_up_proj.
            gate_up_spec = next((spec for spec in gate_up_specs
                                 if spec[1] in name), None)
            if gate_up_spec is not None:
                stride_id, weight_name = gate_up_spec
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                if is_transposed:
                    param = param.T

                shard_size = param.shape[0] // 2
                start = shard_size * tp_rank
                stop = shard_size * (tp_rank + 1)
                loaded_weight = loaded_weight[start:stop]
                dst_start = shard_size * stride_id
                dst_stop = shard_size * (stride_id + 1)
                param_slice = param.data[dst_start:dst_stop]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                continue

            param = state_dict[name]
            if is_transposed:
                param = param.T

            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tp_rank)
            else:
                load_tensor_parallel_weights(param, loaded_weight, name,
                                             column_parallel_weights,
                                             row_parallel_weights, tp_rank)
|
||||||
@@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "Wqkv" in name:
|
if "Wqkv" in name:
|
||||||
# NOTE(woosuk): MPT's fused QKV has the shape of
|
# NOTE(woosuk): MPT's fused QKV has the shape of
|
||||||
# [3 * num_heads * head_size, hidden_size].
|
# [3 * num_heads * head_size, hidden_size].
|
||||||
|
|||||||
@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -76,8 +76,12 @@ class QWenMLP(nn.Module):
|
|||||||
|
|
||||||
class QWenAttention(nn.Module):
|
class QWenAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, num_heads: int,
|
def __init__(self,
|
||||||
max_position_embeddings: int):
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||||
@@ -109,8 +113,9 @@ class QWenAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
|
base=rope_theta,
|
||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
)
|
rope_scaling=rope_scaling)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -135,14 +140,19 @@ class QWenBlock(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, config: QWenConfig):
|
def __init__(self, config: QWenConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
config.max_position_embeddings)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
self.attn = QWenAttention(config.hidden_size,
|
||||||
|
config.num_attention_heads,
|
||||||
|
config.max_position_embeddings,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling)
|
||||||
|
|
||||||
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
|
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -181,11 +191,11 @@ class QWenModel(nn.Module):
|
|||||||
|
|
||||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
self.wte = VocabParallelEmbedding(vocab_size,
|
self.wte = VocabParallelEmbedding(vocab_size,
|
||||||
config.n_embd,
|
config.hidden_size,
|
||||||
perform_initialization=False)
|
perform_initialization=False)
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList(
|
||||||
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||||
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -221,7 +231,7 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
self.transformer = QWenModel(config)
|
self.transformer = QWenModel(config)
|
||||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
self.lm_head = ColumnParallelLinear(
|
self.lm_head = ColumnParallelLinear(
|
||||||
config.n_embd,
|
config.hidden_size,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
gather_output=False,
|
gather_output=False,
|
||||||
@@ -251,13 +261,14 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto",
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None,
|
||||||
):
|
):
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
# Parts of the code here are adapted from PyTorch
|
# Parts of the code here are adapted from PyTorch
|
||||||
# repo: https://github.com/pytorch/pytorch
|
# repo: https://github.com/pytorch/pytorch
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from .mappings import (
|
from .mappings import (
|
||||||
copy_to_tensor_model_parallel_region,
|
|
||||||
gather_from_tensor_model_parallel_region,
|
gather_from_tensor_model_parallel_region,
|
||||||
reduce_from_tensor_model_parallel_region,
|
reduce_from_tensor_model_parallel_region,
|
||||||
scatter_to_tensor_model_parallel_region,
|
scatter_to_tensor_model_parallel_region,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .random import get_cuda_rng_tracker
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
divide,
|
divide,
|
||||||
VocabUtility,
|
VocabUtility,
|
||||||
@@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
|||||||
maybe_copy(attribute)
|
maybe_copy(attribute)
|
||||||
|
|
||||||
|
|
||||||
def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Tag ``weight`` as tensor-parallel and initialize it in place.

    The tensor-parallel attributes (``is_parallel=True``, ``dim``,
    ``stride``) are set on ``weight`` first; ``init_method(weight)`` is then
    called inside the context returned by ``get_cuda_rng_tracker().fork()``.
    """
    parallel_attrs = dict(is_parallel=True, dim=partition_dim, stride=stride)
    set_tensor_model_parallel_attributes(tensor=weight, **parallel_attrs)

    rng_tracker = get_cuda_rng_tracker()
    with rng_tracker.fork():
        init_method(weight)
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_affine_weight_cpu(weight, output_size, input_size,
|
|
||||||
per_partition_size, partition_dim,
|
|
||||||
init_method, stride=1,
|
|
||||||
return_master_weight=False,
|
|
||||||
*, params_dtype=None):
|
|
||||||
"""Initialize affine weight for model parallel.
|
|
||||||
|
|
||||||
Build the master weight on all processes and scatter
|
|
||||||
the relevant chunk."""
|
|
||||||
|
|
||||||
set_tensor_model_parallel_attributes(tensor=weight,
|
|
||||||
is_parallel=True,
|
|
||||||
dim=partition_dim,
|
|
||||||
stride=stride)
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
# Initialize master weight
|
|
||||||
master_weight = torch.empty(output_size, input_size,
|
|
||||||
dtype=torch.float,
|
|
||||||
requires_grad=False)
|
|
||||||
init_method(master_weight)
|
|
||||||
master_weight = master_weight.to(dtype=params_dtype)
|
|
||||||
|
|
||||||
# Split and copy
|
|
||||||
per_partition_per_stride_size = divide(per_partition_size, stride)
|
|
||||||
weight_list = torch.split(master_weight, per_partition_per_stride_size,
|
|
||||||
dim=partition_dim)
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
my_weight_list = weight_list[rank::world_size]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
|
||||||
if return_master_weight:
|
|
||||||
return master_weight
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class VocabParallelEmbedding(torch.nn.Module):
|
class VocabParallelEmbedding(torch.nn.Module):
|
||||||
"""Embedding parallelized in the vocabulary dimension.
|
"""Embedding parallelized in the vocabulary dimension.
|
||||||
|
|
||||||
@@ -138,8 +83,11 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
init_method=init.xavier_normal_,
|
init_method=init.xavier_normal_,
|
||||||
params_dtype: torch.dtype=None,
|
params_dtype: torch.dtype=None,
|
||||||
use_cpu_initialization: bool=False,
|
use_cpu_initialization: bool=False,
|
||||||
perform_initialization: bool=True):
|
perform_initialization: bool=False):
|
||||||
super(VocabParallelEmbedding, self).__init__()
|
super(VocabParallelEmbedding, self).__init__()
|
||||||
|
assert not perform_initialization
|
||||||
|
assert not use_cpu_initialization
|
||||||
|
|
||||||
# Keep the input dimensions.
|
# Keep the input dimensions.
|
||||||
self.num_embeddings = num_embeddings
|
self.num_embeddings = num_embeddings
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
@@ -162,23 +110,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||||
self.vocab_start_index
|
self.vocab_start_index
|
||||||
|
|
||||||
# Allocate weights and initialize.
|
self.weight = Parameter(torch.empty(
|
||||||
if use_cpu_initialization:
|
self.num_embeddings_per_partition, self.embedding_dim,
|
||||||
self.weight = Parameter(torch.empty(
|
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||||
self.num_embeddings_per_partition, self.embedding_dim,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.num_embeddings, self.embedding_dim,
|
|
||||||
self.num_embeddings_per_partition, 0, init_method,
|
|
||||||
params_dtype=params_dtype)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.num_embeddings_per_partition, self.embedding_dim,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=0, stride=1)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
if self.tensor_model_parallel_size > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
@@ -238,9 +172,12 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
skip_bias_add=False,
|
skip_bias_add=False,
|
||||||
params_dtype=None,
|
params_dtype=None,
|
||||||
use_cpu_initialization=False,
|
use_cpu_initialization=False,
|
||||||
perform_initialization=True,
|
perform_initialization=False,
|
||||||
|
quant_config=None,
|
||||||
):
|
):
|
||||||
super(ColumnParallelLinear, self).__init__()
|
super(ColumnParallelLinear, self).__init__()
|
||||||
|
assert not perform_initialization
|
||||||
|
assert not use_cpu_initialization
|
||||||
|
|
||||||
# Keep input parameters
|
# Keep input parameters
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
@@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
self.world_size = get_tensor_model_parallel_world_size()
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
self.output_size_per_partition = divide(output_size, self.world_size)
|
self.output_size_per_partition = divide(output_size, self.world_size)
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
@@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
# Parameters.
|
# Parameters.
|
||||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||||
# we allocate the transpose.
|
# we allocate the transpose.
|
||||||
# Initialize weight.
|
self.create_weights(params_dtype)
|
||||||
if use_cpu_initialization:
|
|
||||||
self.weight = Parameter(torch.empty(self.output_size_per_partition,
|
|
||||||
self.input_size,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
self.master_weight = _initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.output_size, self.input_size,
|
|
||||||
self.output_size_per_partition, 0, init_method,
|
|
||||||
stride=stride, return_master_weight=keep_master_weight_for_test)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.output_size_per_partition, self.input_size,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=0, stride=stride)
|
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
if use_cpu_initialization:
|
self.bias = Parameter(torch.empty(
|
||||||
self.bias = Parameter(torch.empty(
|
self.output_size_per_partition,
|
||||||
self.output_size_per_partition, dtype=params_dtype))
|
device=torch.cuda.current_device(),
|
||||||
else:
|
dtype=params_dtype))
|
||||||
self.bias = Parameter(torch.empty(
|
|
||||||
self.output_size_per_partition,
|
|
||||||
device=torch.cuda.current_device(),
|
|
||||||
dtype=params_dtype))
|
|
||||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
||||||
# Always initialize bias to zero.
|
# Always initialize bias to zero.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter('bias', None)
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
self.weight = Parameter(torch.empty(
|
||||||
|
self.output_size_per_partition, self.input_size,
|
||||||
|
device=torch.cuda.current_device(), dtype=dtype))
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return F.linear(x, self.weight, bias)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
"""Forward of ColumnParallelLinear
|
"""Forward of ColumnParallelLinear
|
||||||
@@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
|
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
output_parallel = F.linear(input_parallel, self.weight, bias)
|
output_parallel = self.apply_weights(input_parallel, bias)
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||||
@@ -359,10 +288,13 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
skip_bias_add=False,
|
skip_bias_add=False,
|
||||||
params_dtype=None,
|
params_dtype=None,
|
||||||
use_cpu_initialization=False,
|
use_cpu_initialization=False,
|
||||||
perform_initialization=True,
|
perform_initialization=False,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
|
quant_config=None,
|
||||||
):
|
):
|
||||||
super(RowParallelLinear, self).__init__()
|
super(RowParallelLinear, self).__init__()
|
||||||
|
assert not perform_initialization
|
||||||
|
assert not use_cpu_initialization
|
||||||
|
|
||||||
# Keep input parameters
|
# Keep input parameters
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
@@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
self.world_size = get_tensor_model_parallel_world_size()
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
self.input_size_per_partition = divide(input_size, self.world_size)
|
self.input_size_per_partition = divide(input_size, self.world_size)
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
self.create_weights(params_dtype)
|
||||||
|
|
||||||
if not reduce_results and (bias and not skip_bias_add):
|
if not reduce_results and (bias and not skip_bias_add):
|
||||||
raise ValueError("When not reduce the results, adding bias to the "
|
raise ValueError("When not reduce the results, adding bias to the "
|
||||||
"results can lead to incorrect results")
|
"results can lead to incorrect results")
|
||||||
|
|
||||||
# Parameters.
|
|
||||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
|
||||||
# we allocate the transpose.
|
|
||||||
# Initialize weight.
|
|
||||||
if use_cpu_initialization:
|
|
||||||
self.weight = Parameter(torch.empty(self.output_size,
|
|
||||||
self.input_size_per_partition,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
self.master_weight = _initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.output_size, self.input_size,
|
|
||||||
self.input_size_per_partition, 1, init_method,
|
|
||||||
stride=stride, return_master_weight=keep_master_weight_for_test,
|
|
||||||
params_dtype=params_dtype)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.output_size, self.input_size_per_partition,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=1, stride=stride)
|
|
||||||
if bias:
|
if bias:
|
||||||
if use_cpu_initialization:
|
self.bias = Parameter(torch.empty(
|
||||||
self.bias = Parameter(torch.empty(self.output_size,
|
self.output_size, device=torch.cuda.current_device(),
|
||||||
dtype=params_dtype))
|
dtype=params_dtype))
|
||||||
else:
|
|
||||||
self.bias = Parameter(torch.empty(
|
|
||||||
self.output_size, device=torch.cuda.current_device(),
|
|
||||||
dtype=params_dtype))
|
|
||||||
|
|
||||||
# Always initialize bias to zero.
|
# Always initialize bias to zero.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.bias.zero_()
|
self.bias.zero_()
|
||||||
else:
|
else:
|
||||||
self.register_parameter('bias', None)
|
self.register_parameter('bias', None)
|
||||||
self.weight_t = self.weight.t()
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
self.weight = Parameter(torch.empty(
|
||||||
|
self.output_size, self.input_size_per_partition,
|
||||||
|
device=torch.cuda.current_device(), dtype=dtype))
|
||||||
|
|
||||||
|
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.linear(x, self.weight)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
"""Forward of RowParallelLinear
|
"""Forward of RowParallelLinear
|
||||||
@@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
output_parallel = F.linear(input_parallel, self.weight)
|
output_parallel = self.apply_weights(input_parallel)
|
||||||
if self.reduce_results and self.world_size > 1:
|
if self.reduce_results and self.world_size > 1:
|
||||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||||
else:
|
else:
|
||||||
|
|||||||
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from vllm.model_executor.quantization_utils.awq import AWQConfig
|
||||||
|
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||||
|
|
||||||
|
_QUANTIZATION_REGISTRY = {
|
||||||
|
"awq": AWQConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
|
||||||
|
if quantization not in _QUANTIZATION_REGISTRY:
|
||||||
|
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||||
|
return _QUANTIZATION_REGISTRY[quantization]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"QuantizationConfig",
|
||||||
|
"get_quant_class",
|
||||||
|
]
|
||||||
72
vllm/model_executor/quantization_utils/awq.py
Normal file
72
vllm/model_executor/quantization_utils/awq.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AWQConfig(QuantizationConfig):
|
||||||
|
"""Config class for AWQ.
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2306.00978
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight_bits: int,
|
||||||
|
group_size: int,
|
||||||
|
zero_point: bool,
|
||||||
|
) -> None:
|
||||||
|
self.weight_bits = weight_bits
|
||||||
|
self.group_size = group_size
|
||||||
|
self.zero_point = zero_point
|
||||||
|
|
||||||
|
if self.weight_bits != 4:
|
||||||
|
raise ValueError(
|
||||||
|
"Currently, only 4-bit weight quantization is supported for "
|
||||||
|
f"AWQ, but got {self.weight_bits} bits.")
|
||||||
|
self.pack_factor = 32 // self.weight_bits
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||||
|
f"group_size={self.group_size}, "
|
||||||
|
f"zero_point={self.zero_point})")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(cls) -> str:
|
||||||
|
return "awq"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
|
return [torch.half]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
# The AWQ kernel only supports Ampere or newer GPUs.
|
||||||
|
return 80
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
return [
|
||||||
|
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||||
|
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||||
|
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||||
|
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||||
|
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||||
|
return cls(weight_bits, group_size, zero_point)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_packed_tensor_names(cls) -> List[str]:
|
||||||
|
return ["qweight", "qzeros"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_transposed_tensor_names(cls) -> List[str]:
|
||||||
|
return ["qweight", "qzeros", "scales"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tp_tensor_names(cls) -> List[str]:
|
||||||
|
return ["qweight", "qzeros", "scales"]
|
||||||
75
vllm/model_executor/quantization_utils/base.py
Normal file
75
vllm/model_executor/quantization_utils/base.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationConfig:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(cls) -> str:
|
||||||
|
"""Name of the quantization method."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
|
"""List of supported activation dtypes."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
"""Minimum GPU capability to support the quantization method.
|
||||||
|
|
||||||
|
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||||
|
This requirement is due to the custom CUDA kernels used by the
|
||||||
|
quantization method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
"""List of filenames to search for in the model directory."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||||
|
"""Create a config class from the model's quantization config."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||||
|
"""Get a value from the model's quantization config."""
|
||||||
|
for key in keys:
|
||||||
|
if key in config:
|
||||||
|
return config[key]
|
||||||
|
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||||
|
"quantization config.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_packed_tensor_names(cls) -> List[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_packed(cls, tensor_name: str) -> bool:
|
||||||
|
"""Returns True if a tensor is packed.
|
||||||
|
|
||||||
|
A tensor is considered packed if each element in the tensor is a
|
||||||
|
packed representation of multiple elements in the original tensor.
|
||||||
|
For example, an INT32 element in the tensor may represent 8 INT4
|
||||||
|
elements in the original tensor.
|
||||||
|
"""
|
||||||
|
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_transposed_tensor_names(cls) -> List[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_transposed(cls, tensor_name: str) -> bool:
|
||||||
|
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
|
||||||
|
"""
|
||||||
|
return any(tag in tensor_name
|
||||||
|
for tag in cls.get_transposed_tensor_names())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tp_tensor_names(cls) -> List[str]:
|
||||||
|
raise NotImplementedError
|
||||||
@@ -4,7 +4,7 @@ import glob
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Iterator, List, Optional, Tuple, Any
|
from typing import Any, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from safetensors.torch import load_file, save_file, safe_open
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
@@ -13,6 +13,8 @@ import torch
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.quantization_utils import get_quant_class
|
||||||
|
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -44,7 +46,7 @@ def _shared_pointers(tensors):
|
|||||||
def convert_bin_to_safetensor_file(
|
def convert_bin_to_safetensor_file(
|
||||||
pt_filename: str,
|
pt_filename: str,
|
||||||
sf_filename: str,
|
sf_filename: str,
|
||||||
):
|
) -> None:
|
||||||
loaded = torch.load(pt_filename, map_location="cpu")
|
loaded = torch.load(pt_filename, map_location="cpu")
|
||||||
if "state_dict" in loaded:
|
if "state_dict" in loaded:
|
||||||
loaded = loaded["state_dict"]
|
loaded = loaded["state_dict"]
|
||||||
@@ -78,15 +80,55 @@ def convert_bin_to_safetensor_file(
|
|||||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(woosuk): Move this to other place.
|
||||||
|
def get_quant_config(
|
||||||
|
quantization: str,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
) -> QuantizationConfig:
|
||||||
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
|
if not is_local:
|
||||||
|
# Download the config files.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
|
allow_patterns="*.json",
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
tqdm_class=Disabledtqdm)
|
||||||
|
else:
|
||||||
|
hf_folder = model_name_or_path
|
||||||
|
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||||
|
|
||||||
|
quant_cls = get_quant_class(quantization)
|
||||||
|
quant_config_files = [
|
||||||
|
f for f in config_files if any(
|
||||||
|
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||||
|
]
|
||||||
|
if len(quant_config_files) == 0:
|
||||||
|
raise ValueError(f"Cannot find the config file for {quantization}")
|
||||||
|
if len(quant_config_files) > 1:
|
||||||
|
raise ValueError(f"Found multiple config files for {quantization}: "
|
||||||
|
f"{quant_config_files}")
|
||||||
|
|
||||||
|
quant_config_file = quant_config_files[0]
|
||||||
|
with open(quant_config_file, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return quant_cls.from_config(config)
|
||||||
|
|
||||||
|
|
||||||
def prepare_hf_model_weights(
|
def prepare_hf_model_weights(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_safetensors: bool = False,
|
use_safetensors: bool = False,
|
||||||
fall_back_to_pt: bool = True,
|
fall_back_to_pt: bool = True,
|
||||||
):
|
revision: Optional[str] = None,
|
||||||
|
) -> Tuple[str, List[str], bool]:
|
||||||
# Download model weights from huggingface.
|
# Download model weights from huggingface.
|
||||||
is_local = os.path.isdir(model_name_or_path)
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
|
if use_safetensors:
|
||||||
|
allow_patterns = ["*.safetensors"]
|
||||||
|
else:
|
||||||
|
# Some quantized models use .pt files for storing the weights.
|
||||||
|
allow_patterns = ["*.bin", "*.pt"]
|
||||||
if not is_local:
|
if not is_local:
|
||||||
# Use file lock to prevent multiple processes from
|
# Use file lock to prevent multiple processes from
|
||||||
# downloading the same model weights at the same time.
|
# downloading the same model weights at the same time.
|
||||||
@@ -94,10 +136,13 @@ def prepare_hf_model_weights(
|
|||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
allow_patterns=allow_patterns,
|
allow_patterns=allow_patterns,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
tqdm_class=Disabledtqdm)
|
tqdm_class=Disabledtqdm,
|
||||||
|
revision=revision)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
|
hf_weights_files: List[str] = []
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||||
if not use_safetensors:
|
if not use_safetensors:
|
||||||
hf_weights_files = [
|
hf_weights_files = [
|
||||||
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||||
@@ -107,7 +152,8 @@ def prepare_hf_model_weights(
|
|||||||
return prepare_hf_model_weights(model_name_or_path,
|
return prepare_hf_model_weights(model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
use_safetensors=False,
|
use_safetensors=False,
|
||||||
fall_back_to_pt=False)
|
fall_back_to_pt=False,
|
||||||
|
revision=revision)
|
||||||
|
|
||||||
if len(hf_weights_files) == 0:
|
if len(hf_weights_files) == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -120,6 +166,7 @@ def hf_model_weights_iterator(
|
|||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto",
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None,
|
||||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||||
use_safetensors = False
|
use_safetensors = False
|
||||||
use_np_cache = False
|
use_np_cache = False
|
||||||
@@ -140,7 +187,8 @@ def hf_model_weights_iterator(
|
|||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
use_safetensors=use_safetensors,
|
use_safetensors=use_safetensors,
|
||||||
fall_back_to_pt=fall_back_to_pt)
|
fall_back_to_pt=fall_back_to_pt,
|
||||||
|
revision=revision)
|
||||||
|
|
||||||
if use_np_cache:
|
if use_np_cache:
|
||||||
# Currently np_cache only support *.bin checkpoints
|
# Currently np_cache only support *.bin checkpoints
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
|
from enum import IntEnum
|
||||||
|
from functools import cached_property
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingType(IntEnum):
|
||||||
|
GREEDY = 0
|
||||||
|
RANDOM = 1
|
||||||
|
BEAM = 2
|
||||||
|
|
||||||
|
|
||||||
class SamplingParams:
|
class SamplingParams:
|
||||||
"""Sampling parameters for text generation.
|
"""Sampling parameters for text generation.
|
||||||
|
|
||||||
@@ -45,10 +53,15 @@ class SamplingParams:
|
|||||||
(canonical beam search algorithm).
|
(canonical beam search algorithm).
|
||||||
stop: List of strings that stop the generation when they are generated.
|
stop: List of strings that stop the generation when they are generated.
|
||||||
The returned output will not contain the stop strings.
|
The returned output will not contain the stop strings.
|
||||||
|
stop_token_ids: List of tokens that stop the generation when they are
|
||||||
|
generated. The returned output will contain the stop tokens unless
|
||||||
|
the stop tokens are sepcial tokens.
|
||||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||||
tokens after the EOS token is generated.
|
tokens after the EOS token is generated.
|
||||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||||
logprobs: Number of log probabilities to return per output token.
|
logprobs: Number of log probabilities to return per output token.
|
||||||
|
skip_special_tokens: Whether to skip special tokens in the output.
|
||||||
|
Defaults to true.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -64,9 +77,11 @@ class SamplingParams:
|
|||||||
length_penalty: float = 1.0,
|
length_penalty: float = 1.0,
|
||||||
early_stopping: Union[bool, str] = False,
|
early_stopping: Union[bool, str] = False,
|
||||||
stop: Union[None, str, List[str]] = None,
|
stop: Union[None, str, List[str]] = None,
|
||||||
|
stop_token_ids: List[int] = None,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
max_tokens: int = 16,
|
max_tokens: int = 16,
|
||||||
logprobs: Optional[int] = None,
|
logprobs: Optional[int] = None,
|
||||||
|
skip_special_tokens: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.n = n
|
self.n = n
|
||||||
self.best_of = best_of if best_of is not None else n
|
self.best_of = best_of if best_of is not None else n
|
||||||
@@ -84,9 +99,14 @@ class SamplingParams:
|
|||||||
self.stop = [stop]
|
self.stop = [stop]
|
||||||
else:
|
else:
|
||||||
self.stop = list(stop)
|
self.stop = list(stop)
|
||||||
|
if stop_token_ids is None:
|
||||||
|
self.stop_token_ids = []
|
||||||
|
else:
|
||||||
|
self.stop_token_ids = list(stop_token_ids)
|
||||||
self.ignore_eos = ignore_eos
|
self.ignore_eos = ignore_eos
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.logprobs = logprobs
|
self.logprobs = logprobs
|
||||||
|
self.skip_special_tokens = skip_special_tokens
|
||||||
|
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
if self.use_beam_search:
|
if self.use_beam_search:
|
||||||
@@ -158,6 +178,14 @@ class SamplingParams:
|
|||||||
if self.top_k != -1:
|
if self.top_k != -1:
|
||||||
raise ValueError("top_k must be -1 when using greedy sampling.")
|
raise ValueError("top_k must be -1 when using greedy sampling.")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def sampling_type(self) -> SamplingType:
|
||||||
|
if self.use_beam_search:
|
||||||
|
return SamplingType.BEAM
|
||||||
|
if self.temperature < _SAMPLING_EPS:
|
||||||
|
return SamplingType.GREEDY
|
||||||
|
return SamplingType.RANDOM
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"SamplingParams(n={self.n}, "
|
return (f"SamplingParams(n={self.n}, "
|
||||||
f"best_of={self.best_of}, "
|
f"best_of={self.best_of}, "
|
||||||
@@ -172,4 +200,5 @@ class SamplingParams:
|
|||||||
f"stop={self.stop}, "
|
f"stop={self.stop}, "
|
||||||
f"ignore_eos={self.ignore_eos}, "
|
f"ignore_eos={self.ignore_eos}, "
|
||||||
f"max_tokens={self.max_tokens}, "
|
f"max_tokens={self.max_tokens}, "
|
||||||
f"logprobs={self.logprobs})")
|
f"logprobs={self.logprobs}, "
|
||||||
|
f"skip_special_tokens={self.skip_special_tokens})")
|
||||||
|
|||||||
@@ -114,7 +114,6 @@ class Sequence:
|
|||||||
|
|
||||||
self.data = SequenceData(prompt_token_ids)
|
self.data = SequenceData(prompt_token_ids)
|
||||||
self.output_logprobs: List[Dict[int, float]] = []
|
self.output_logprobs: List[Dict[int, float]] = []
|
||||||
self.output_tokens: List[str] = []
|
|
||||||
self.output_text = ""
|
self.output_text = ""
|
||||||
|
|
||||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||||
@@ -122,6 +121,12 @@ class Sequence:
|
|||||||
self._append_tokens_to_blocks(prompt_token_ids)
|
self._append_tokens_to_blocks(prompt_token_ids)
|
||||||
self.status = SequenceStatus.WAITING
|
self.status = SequenceStatus.WAITING
|
||||||
|
|
||||||
|
# Used for incremental detokenization
|
||||||
|
self.prefix_offset = 0
|
||||||
|
self.read_offset = 0
|
||||||
|
# Input + output tokens
|
||||||
|
self.tokens: Optional[List[str]] = None
|
||||||
|
|
||||||
def _append_logical_block(self) -> None:
|
def _append_logical_block(self) -> None:
|
||||||
block = LogicalTokenBlock(
|
block = LogicalTokenBlock(
|
||||||
block_number=len(self.logical_token_blocks),
|
block_number=len(self.logical_token_blocks),
|
||||||
@@ -245,8 +250,8 @@ class SequenceGroup:
|
|||||||
# generation stage, we will have `best_of` sequences running.
|
# generation stage, we will have `best_of` sequences running.
|
||||||
return self.sampling_params.best_of
|
return self.sampling_params.best_of
|
||||||
# At sampling stages, return the number of actual sequences
|
# At sampling stages, return the number of actual sequences
|
||||||
# running.
|
# that are not finished yet.
|
||||||
return self.num_seqs(status=SequenceStatus.RUNNING)
|
return self.num_unfinished_seqs()
|
||||||
|
|
||||||
def get_seqs(
|
def get_seqs(
|
||||||
self,
|
self,
|
||||||
@@ -259,12 +264,23 @@ class SequenceGroup:
|
|||||||
seq for seq in self.seqs_dict.values() if seq.status == status
|
seq for seq in self.seqs_dict.values() if seq.status == status
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_unfinished_seqs(self) -> List[Sequence]:
    """Return the sequences of this group whose ``is_finished()`` is false.

    The order follows the iteration order of ``self.seqs_dict``.
    """
    unfinished = []
    for seq in self.seqs_dict.values():
        if seq.is_finished():
            continue
        unfinished.append(seq)
    return unfinished
|
||||||
|
|
||||||
def get_finished_seqs(self) -> List[Sequence]:
|
def get_finished_seqs(self) -> List[Sequence]:
|
||||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||||
|
|
||||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||||
return len(self.get_seqs(status))
|
return len(self.get_seqs(status))
|
||||||
|
|
||||||
|
def num_unfinished_seqs(self) -> int:
    """Return how many sequences ``get_unfinished_seqs()`` reports."""
    unfinished = self.get_unfinished_seqs()
    return len(unfinished)
|
||||||
|
|
||||||
|
def num_finished_seqs(self) -> int:
    """Return how many sequences ``get_finished_seqs()`` reports."""
    finished = self.get_finished_seqs()
    return len(finished)
|
||||||
|
|
||||||
def find(self, seq_id: int) -> Sequence:
|
def find(self, seq_id: int) -> Sequence:
|
||||||
if seq_id not in self.seqs_dict:
|
if seq_id not in self.seqs_dict:
|
||||||
raise ValueError(f"Sequence {seq_id} not found.")
|
raise ValueError(f"Sequence {seq_id} not found.")
|
||||||
@@ -345,7 +361,7 @@ class SequenceOutputs:
|
|||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
    """Compare two ``SequenceOutputs`` field by field.

    Two instances are equal when ``parent_seq_id``, ``output_token`` and
    ``logprobs`` all match.

    Returns:
        The comparison result, or ``NotImplemented`` when ``other`` is not
        a ``SequenceOutputs``.
    """
    if not isinstance(other, SequenceOutputs):
        # Returning the NotImplemented singleton (rather than returning or
        # raising NotImplementedError) lets Python try the reflected
        # comparison and finally fall back to identity, so `x == 5` is
        # simply False instead of an exception or a truthy object.
        return NotImplemented
    return (self.parent_seq_id == other.parent_seq_id
            and self.output_token == other.output_token
            and self.logprobs == other.logprobs)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import AutoConfig, PretrainedConfig
|
from transformers import AutoConfig, PretrainedConfig
|
||||||
|
|
||||||
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
|
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
|
||||||
@@ -12,10 +14,21 @@ _CONFIG_REGISTRY = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
|
def get_config(model: str,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
revision: Optional[str] = None) -> PretrainedConfig:
|
||||||
|
# NOTE: Because the Mistral model in HF hub does not have
|
||||||
|
# `configuration_mistral.py`, we cannot use `AutoConfig` to load the
|
||||||
|
# config. Instead, we use `MistralConfig` directly.
|
||||||
|
# NOTE: This is a hack. This does not work for local models.
|
||||||
|
# FIXME: Remove this once the Mistral model is available in the stable
|
||||||
|
# version of HF transformers.
|
||||||
|
if "mistral" in model.lower():
|
||||||
|
return MistralConfig.from_pretrained(model, revision=revision)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model, trust_remote_code=trust_remote_code)
|
model, trust_remote_code=trust_remote_code, revision=revision)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if (not trust_remote_code and
|
if (not trust_remote_code and
|
||||||
"requires you to execute the configuration file" in str(e)):
|
"requires you to execute the configuration file" in str(e)):
|
||||||
@@ -29,5 +42,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
|
|||||||
raise e
|
raise e
|
||||||
if config.model_type in _CONFIG_REGISTRY:
|
if config.model_type in _CONFIG_REGISTRY:
|
||||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||||
config = config_class.from_pretrained(model)
|
config = config_class.from_pretrained(model, revision=revision)
|
||||||
return config
|
return config
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
|
|||||||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||||
from vllm.transformers_utils.configs.falcon import RWConfig
|
from vllm.transformers_utils.configs.falcon import RWConfig
|
||||||
|
from vllm.transformers_utils.configs.mistral import MistralConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MPTConfig",
|
"MPTConfig",
|
||||||
@@ -13,4 +14,5 @@ __all__ = [
|
|||||||
"AquilaConfig",
|
"AquilaConfig",
|
||||||
"QWenConfig",
|
"QWenConfig",
|
||||||
"RWConfig",
|
"RWConfig",
|
||||||
|
"MistralConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
66
vllm/transformers_utils/configs/mistral.py
Normal file
66
vllm/transformers_utils/configs/mistral.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Mistral-7B-v0.1 configuration"""
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MistralConfig(PretrainedConfig):
    """Configuration holder for Mistral-7B-v0.1 style models.

    Stores the architecture hyper-parameters as instance attributes and
    forwards the special token ids, ``tie_word_embeddings`` and any extra
    keyword arguments to ``PretrainedConfig.__init__``.
    """

    model_type = "mistral"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=4096,
        **kwargs,
    ):
        # Model dimensions.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        # For backward compatibility: when no KV head count is given, use
        # as many KV heads as attention heads.
        self.num_key_value_heads = (num_attention_heads
                                    if num_key_value_heads is None else
                                    num_key_value_heads)

        # Remaining hyper-parameters.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        # The base class is initialised last, after the attributes above
        # have been set.
        base_kwargs = dict(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
        )
        super().__init__(**base_kwargs, **kwargs)
|
||||||
@@ -7,65 +7,54 @@ from transformers import PretrainedConfig
|
|||||||
class QWenConfig(PretrainedConfig):
|
class QWenConfig(PretrainedConfig):
|
||||||
model_type = "qwen"
|
model_type = "qwen"
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
attribute_map = {
|
|
||||||
"hidden_size": "n_embd",
|
|
||||||
"num_attention_heads": "n_head",
|
|
||||||
"max_position_embeddings": "n_positions",
|
|
||||||
"num_hidden_layers": "n_layer",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size=151851,
|
vocab_size=151936,
|
||||||
n_embd=4096,
|
hidden_size=4096,
|
||||||
n_layer=32,
|
num_hidden_layers=32,
|
||||||
n_head=32,
|
num_attention_heads=32,
|
||||||
n_inner=None,
|
emb_dropout_prob=0.0,
|
||||||
embd_pdrop=0.0,
|
attn_dropout_prob=0.0,
|
||||||
attn_pdrop=0.0,
|
layer_norm_epsilon=1e-6,
|
||||||
layer_norm_epsilon=1e-5,
|
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
max_position_embeddings=8192,
|
||||||
scale_attn_weights=True,
|
scale_attn_weights=True,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
eos_token_id=151643,
|
bf16=False,
|
||||||
apply_residual_connection_post_layernorm=False,
|
fp16=False,
|
||||||
bf16=True,
|
fp32=False,
|
||||||
kv_channels=128,
|
kv_channels=128,
|
||||||
rotary_pct=1.0,
|
rotary_pct=1.0,
|
||||||
rotary_emb_base=10000,
|
rotary_emb_base=10000,
|
||||||
use_dynamic_ntk=False,
|
use_dynamic_ntk=True,
|
||||||
use_logn_attn=False,
|
use_logn_attn=True,
|
||||||
use_flash_attn=True,
|
use_flash_attn="auto",
|
||||||
ffn_hidden_size=22016,
|
intermediate_size=22016,
|
||||||
no_bias=True,
|
no_bias=True,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.eos_token_id = eos_token_id
|
|
||||||
super().__init__(eos_token_id=eos_token_id,
|
|
||||||
tie_word_embeddings=tie_word_embeddings,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.n_embd = n_embd
|
self.hidden_size = hidden_size
|
||||||
self.n_layer = n_layer
|
self.intermediate_size = intermediate_size
|
||||||
self.n_head = n_head
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.n_inner = n_inner
|
self.num_attention_heads = num_attention_heads
|
||||||
self.embd_pdrop = embd_pdrop
|
self.emb_dropout_prob = emb_dropout_prob
|
||||||
self.attn_pdrop = attn_pdrop
|
self.attn_dropout_prob = attn_dropout_prob
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.scale_attn_weights = scale_attn_weights
|
self.scale_attn_weights = scale_attn_weights
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.apply_residual_connection_post_layernorm = (
|
self.max_position_embeddings = max_position_embeddings
|
||||||
apply_residual_connection_post_layernorm)
|
|
||||||
self.bf16 = bf16
|
self.bf16 = bf16
|
||||||
|
self.fp16 = fp16
|
||||||
|
self.fp32 = fp32
|
||||||
self.kv_channels = kv_channels
|
self.kv_channels = kv_channels
|
||||||
self.rotary_pct = rotary_pct
|
self.rotary_pct = rotary_pct
|
||||||
self.rotary_emb_base = rotary_emb_base
|
self.rotary_emb_base = rotary_emb_base
|
||||||
self.use_dynamic_ntk = use_dynamic_ntk
|
self.use_dynamic_ntk = use_dynamic_ntk
|
||||||
self.use_logn_attn = use_logn_attn
|
self.use_logn_attn = use_logn_attn
|
||||||
self.use_flash_attn = use_flash_attn
|
self.use_flash_attn = use_flash_attn
|
||||||
self.ffn_hidden_size = ffn_hidden_size
|
|
||||||
self.no_bias = no_bias
|
self.no_bias = no_bias
|
||||||
self.tie_word_embeddings = tie_word_embeddings
|
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerFast)
|
PreTrainedTokenizerFast)
|
||||||
@@ -28,8 +28,8 @@ def get_tokenizer(
|
|||||||
if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
|
if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
|
||||||
and tokenizer_name != _FAST_LLAMA_TOKENIZER):
|
and tokenizer_name != _FAST_LLAMA_TOKENIZER):
|
||||||
logger.info(
|
logger.info(
|
||||||
"For some LLaMA-based models, initializing the fast tokenizer may "
|
"For some LLaMA V1 models, initializing the fast tokenizer may "
|
||||||
"take a long time. To eliminate the initialization time, consider "
|
"take a long time. To reduce the initialization time, consider "
|
||||||
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
|
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
|
||||||
"tokenizer.")
|
"tokenizer.")
|
||||||
try:
|
try:
|
||||||
@@ -41,9 +41,9 @@ def get_tokenizer(
|
|||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
# The LLaMA tokenizer causes a protobuf error in some environments.
|
# The LLaMA tokenizer causes a protobuf error in some environments.
|
||||||
err_msg = (
|
err_msg = (
|
||||||
"Failed to load the tokenizer. If you are using a LLaMA-based "
|
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
|
||||||
f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
|
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
|
||||||
"tokenizer.")
|
"original tokenizer.")
|
||||||
raise RuntimeError(err_msg) from e
|
raise RuntimeError(err_msg) from e
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# If the error pertains to the tokenizer class not existing or not
|
# If the error pertains to the tokenizer class not existing or not
|
||||||
@@ -67,33 +67,11 @@ def get_tokenizer(
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
def detokenize_incrementally(
|
def _convert_tokens_to_string_with_added_encoders(
|
||||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
prev_output_tokens: List[str],
|
output_tokens: List[str],
|
||||||
new_token_id: int,
|
|
||||||
skip_special_tokens: bool,
|
skip_special_tokens: bool,
|
||||||
) -> Tuple[str, str]:
|
) -> str:
|
||||||
"""Detokenizes the new token in conjunction with the previous output tokens.
|
|
||||||
|
|
||||||
NOTE: This function does not update prev_output_tokens.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
new_token: The new token as a string.
|
|
||||||
output_text: The new output text as a string.
|
|
||||||
"""
|
|
||||||
if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
|
|
||||||
return None, prev_output_tokens
|
|
||||||
new_token = tokenizer.convert_ids_to_tokens(
|
|
||||||
new_token_id, skip_special_tokens=skip_special_tokens)
|
|
||||||
output_tokens = prev_output_tokens + [new_token]
|
|
||||||
|
|
||||||
# Convert the tokens to a string.
|
|
||||||
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
|
|
||||||
# then we can directly use `convert_tokens_to_string`.
|
|
||||||
if not getattr(tokenizer, "added_tokens_encoder", {}):
|
|
||||||
output_text = tokenizer.convert_tokens_to_string(output_tokens)
|
|
||||||
return new_token, output_text
|
|
||||||
|
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
|
||||||
# NOTE(woosuk): The following code is slow because it runs a for loop over
|
# NOTE(woosuk): The following code is slow because it runs a for loop over
|
||||||
@@ -115,5 +93,61 @@ def detokenize_incrementally(
|
|||||||
if current_sub_text:
|
if current_sub_text:
|
||||||
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
|
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
|
||||||
sub_texts.append(sub_text)
|
sub_texts.append(sub_text)
|
||||||
output_text = " ".join(sub_texts)
|
return " ".join(sub_texts)
|
||||||
return new_token, output_text
|
|
||||||
|
|
||||||
|
# Based on
|
||||||
|
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
|
||||||
|
# under Apache 2.0 license
|
||||||
|
def detokenize_incrementally(
    tokenizer: "Union[PreTrainedTokenizer, PreTrainedTokenizerFast]",
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int = 0,
    read_offset: int = 0,
    skip_special_tokens: bool = False,
) -> Tuple[List[str], str, int, int]:
    """Detokenize the newest token given the tokens decoded so far.

    Args:
        tokenizer: Tokenizer used to map ids to tokens and tokens to text.
        all_input_ids: All token ids of the sequence; the last element is
            the newly generated token.
        prev_tokens: Token strings from earlier calls, or ``None`` on the
            first call for a sequence (then every id is converted and the
            passed-in offsets are recomputed).
        prefix_offset: Start of the token window used as decoding context.
        read_offset: End of the already-emitted part of that window.
        skip_special_tokens: Forwarded to the tokenizer conversions.

    Returns:
        ``(new_tokens, new_text, prefix_offset, read_offset)``. When text
        is emitted the returned offsets are the old ``read_offset`` and the
        new token count; when text is withheld ``new_text`` is ``""`` and
        the offsets are returned unchanged.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    if prev_tokens is None:
        new_tokens = tokenizer.convert_ids_to_tokens(
            all_input_ids, skip_special_tokens=skip_special_tokens)
        output_tokens = new_tokens
        # 5 is an arbitrary value that should work for all
        # tokenizers (bigger = more conservative).
        # Subtract 1 extra to account for the generated token.
        prefix_offset = max(len(output_tokens) - 6, 0)
        read_offset = max(len(output_tokens) - 1, 0)
    else:
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
        output_tokens = prev_tokens + new_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if not getattr(tokenizer, "added_tokens_encoder", {}):
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens)
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens)

    # "\ufffd" is the Unicode replacement character. It is written as an
    # escape so the sentinel cannot be corrupted by re-encoding the file.
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # A replacement character at the end means it's a potential
        # unfinished byte sequence from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        new_text = new_text[len(prefix_text):]
        return new_tokens, new_text, read_offset, len(output_tokens)
    else:
        return new_tokens, "", prefix_offset, read_offset
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import enum
|
import enum
|
||||||
from platform import uname
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from platform import uname
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm import cuda_utils
|
||||||
|
|
||||||
|
|
||||||
class Device(enum.Enum):
|
class Device(enum.Enum):
|
||||||
GPU = enum.auto()
|
GPU = enum.auto()
|
||||||
@@ -25,6 +27,15 @@ class Counter:
|
|||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # Numeric id of cudaDevAttrMaxSharedMemoryPerBlockOptin, see
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
    attr_id = 97
    return int(cuda_utils.get_device_attribute(attr_id, gpu))
|
||||||
|
|
||||||
|
|
||||||
def get_gpu_memory(gpu: int = 0) -> int:
|
def get_gpu_memory(gpu: int = 0) -> int:
|
||||||
"""Returns the total memory of the GPU in bytes."""
|
"""Returns the total memory of the GPU in bytes."""
|
||||||
return torch.cuda.get_device_properties(gpu).total_memory
|
return torch.cuda.get_device_properties(gpu).total_memory
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class CacheEngine:
|
|||||||
|
|
||||||
self.head_size = model_config.get_head_size()
|
self.head_size = model_config.get_head_size()
|
||||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||||
self.num_heads = model_config.get_num_heads(parallel_config)
|
self.num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
self.dtype = model_config.dtype
|
self.dtype = model_config.dtype
|
||||||
|
|
||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
@@ -146,7 +146,7 @@ class CacheEngine:
|
|||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
) -> int:
|
) -> int:
|
||||||
head_size = model_config.get_head_size()
|
head_size = model_config.get_head_size()
|
||||||
num_heads = model_config.get_num_heads(parallel_config)
|
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
num_layers = model_config.get_num_layers(parallel_config)
|
num_layers = model_config.get_num_layers(parallel_config)
|
||||||
|
|
||||||
key_cache_block = block_size * num_heads * head_size
|
key_cache_block = block_size * num_heads * head_size
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||||
from vllm.worker.cache_engine import CacheEngine
|
from vllm.worker.cache_engine import CacheEngine
|
||||||
from vllm.utils import get_gpu_memory
|
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
@@ -42,6 +42,7 @@ class Worker:
|
|||||||
# self.init_cache_engine().
|
# self.init_cache_engine().
|
||||||
self.cache_config = None
|
self.cache_config = None
|
||||||
self.block_size = None
|
self.block_size = None
|
||||||
|
self.sliding_window = None
|
||||||
self.cache_engine = None
|
self.cache_engine = None
|
||||||
self.cache_events = None
|
self.cache_events = None
|
||||||
self.gpu_cache = None
|
self.gpu_cache = None
|
||||||
@@ -136,6 +137,15 @@ class Worker:
|
|||||||
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
|
self.sliding_window = cache_config.sliding_window
|
||||||
|
|
||||||
|
if self.sliding_window is None:
|
||||||
|
max_seq_len = self.scheduler_config.max_model_len
|
||||||
|
else:
|
||||||
|
max_seq_len = min(self.scheduler_config.max_model_len,
|
||||||
|
self.sliding_window)
|
||||||
|
_check_if_can_support_max_seq_len(max_seq_len, self.block_size)
|
||||||
|
|
||||||
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
self.cache_events = self.cache_engine.events
|
self.cache_events = self.cache_engine.events
|
||||||
@@ -207,10 +217,11 @@ class Worker:
|
|||||||
|
|
||||||
context_len = seq_data.get_len()
|
context_len = seq_data.get_len()
|
||||||
position = context_len - 1
|
position = context_len - 1
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
context_len = min(context_len, self.sliding_window)
|
||||||
input_positions.append(position)
|
input_positions.append(position)
|
||||||
|
|
||||||
block_table = seq_group_metadata.block_tables[seq_id]
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
generation_block_tables.append(block_table)
|
|
||||||
|
|
||||||
max_context_len = max(max_context_len, context_len)
|
max_context_len = max(max_context_len, context_len)
|
||||||
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
|
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
|
||||||
@@ -222,21 +233,37 @@ class Worker:
|
|||||||
slot = block_number * self.block_size + block_offset
|
slot = block_number * self.block_size + block_offset
|
||||||
slot_mapping.append(slot)
|
slot_mapping.append(slot)
|
||||||
|
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
sliding_window_blocks = (self.sliding_window //
|
||||||
|
self.block_size)
|
||||||
|
block_table = block_table[-sliding_window_blocks:]
|
||||||
|
generation_block_tables.append(block_table)
|
||||||
|
|
||||||
# Optimization: Pad the input length to be a multiple of 8.
|
# Optimization: Pad the input length to be a multiple of 8.
|
||||||
# This is required for utilizing the Tensor Cores in NVIDIA GPUs.
|
# This is required for utilizing the Tensor Cores in NVIDIA GPUs.
|
||||||
input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
|
input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
|
||||||
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
|
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
|
||||||
|
|
||||||
# Convert to tensors.
|
# Convert to tensors.
|
||||||
tokens_tensor = torch.cuda.LongTensor(input_tokens)
|
tokens_tensor = torch.tensor(input_tokens,
|
||||||
positions_tensor = torch.cuda.LongTensor(input_positions)
|
dtype=torch.long,
|
||||||
slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
|
device="cuda")
|
||||||
context_lens_tensor = torch.cuda.IntTensor(context_lens)
|
positions_tensor = torch.tensor(input_positions,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cuda")
|
||||||
|
slot_mapping_tensor = torch.tensor(slot_mapping,
|
||||||
|
dtype=torch.int,
|
||||||
|
device="cuda")
|
||||||
|
context_lens_tensor = torch.tensor(context_lens,
|
||||||
|
dtype=torch.int,
|
||||||
|
device="cuda")
|
||||||
padded_block_tables = [
|
padded_block_tables = [
|
||||||
_pad_to_max(block_table, max_num_blocks_per_seq)
|
_pad_to_max(block_table, max_num_blocks_per_seq)
|
||||||
for block_table in generation_block_tables
|
for block_table in generation_block_tables
|
||||||
]
|
]
|
||||||
block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
|
block_tables_tensor = torch.tensor(padded_block_tables,
|
||||||
|
dtype=torch.int,
|
||||||
|
device="cuda")
|
||||||
|
|
||||||
seq_data: Dict[int, SequenceData] = {}
|
seq_data: Dict[int, SequenceData] = {}
|
||||||
for seq_group_metadata in seq_group_metadata_list:
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
@@ -250,6 +277,7 @@ class Worker:
|
|||||||
context_lens=context_lens_tensor,
|
context_lens=context_lens_tensor,
|
||||||
max_context_len=max_context_len,
|
max_context_len=max_context_len,
|
||||||
block_tables=block_tables_tensor,
|
block_tables=block_tables_tensor,
|
||||||
|
sliding_window=self.sliding_window,
|
||||||
)
|
)
|
||||||
return tokens_tensor, positions_tensor, input_metadata
|
return tokens_tensor, positions_tensor, input_metadata
|
||||||
|
|
||||||
@@ -337,3 +365,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
|
|||||||
|
|
||||||
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
|
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
|
||||||
return x + [0] * (max_len - len(x))
|
return x + [0] * (max_len - len(x))
|
||||||
|
|
||||||
|
|
||||||
|
def _check_if_can_support_max_seq_len(max_seq_len: int,
                                      block_size: int) -> None:
    """Raise if the GPU's shared memory cannot hold ``max_seq_len`` logits.

    Args:
        max_seq_len: Longest sequence length to support.
        block_size: KV-cache block size; the length is rounded up to a
            multiple of it before the size check.

    Raises:
        RuntimeError: If one float32 per padded position exceeds the value
            returned by ``get_max_shared_memory_bytes()``.
    """
    # Follows the logic in
    # attention_kernels.cu::single_query_cached_kv_attention_launcher
    max_shared_mem = get_max_shared_memory_bytes()
    float32_bytes = torch.finfo(torch.float).bits // 8
    # Floor division is needed for a true round-up to a multiple of
    # block_size. With `/` the result is a float that overestimates the
    # padded length by up to block_size - 1.
    padded_max_seq_len = (
        (max_seq_len + block_size - 1) // block_size) * block_size
    # padded_max_seq_len + extra buffer
    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
    # NOTE(review): the comparison excludes the 512-element buffer that the
    # error message reports; confirm which of the two is intended.
    if padded_max_seq_len * float32_bytes > max_shared_mem:
        raise RuntimeError(
            f"vLLM cannot currently support max_model_len={max_seq_len} "
            f"with block_size={block_size} on GPU with compute "
            f"capability {torch.cuda.get_device_capability()} "
            f"(required shared memory {required_shared_mem} > "
            f"available shared memory {max_shared_mem}). "
            "This will be fixed in a future release.")
|
||||||
|
|||||||
Reference in New Issue
Block a user