Compare commits

..

143 Commits

Author SHA1 Message Date
Woosuk Kwon
90eb3f43ca Bump up the version to v0.1.7 (#1013)
2023-09-11 00:54:30 -07:00
Woosuk Kwon
e67b4f2c2a Use FP32 in RoPE initialization (#1004)
Co-authored-by: One <imone@tuta.io>
2023-09-11 00:26:35 -07:00
Woosuk Kwon
d6770d1f23 Update setup.py (#1006) 2023-09-10 23:42:45 -07:00
Woosuk Kwon
b9cecc2635 [Docs] Update installation page (#1005) 2023-09-10 14:23:31 -07:00
Kyujin Cho
898285c9bf fix: CUDA error when inferencing with Falcon-40B base model (#992) 2023-09-10 01:39:02 -07:00
Antoni Baum
a62de9ecfd Fix wrong dtype in PagedAttentionWithALiBi bias (#996)
---------

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
2023-09-09 14:58:35 -07:00
Jingru
4042d192f5 fix "tansformers_module" ModuleNotFoundError when load model with trust_remote_code=True (#871) 2023-09-08 17:21:30 -07:00
Zhuohan Li
1117aa1411 Bump up the version to v0.1.6 (#989)
2023-09-08 00:07:46 -07:00
Antoni Baum
080438477f Start background task in AsyncLLMEngine.generate (#988)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-09-08 00:03:39 -07:00
Robert Irvine
4b5bcf8906 faster startup of vLLM (#982)
* update

---------

Co-authored-by: Robert Irvine <robert@seamlessml.com>
2023-09-08 14:48:54 +09:00
Woosuk Kwon
852ef5b4f5 Bump up the version to v0.1.5 (#944)
2023-09-07 16:15:31 -07:00
Zhuohan Li
db09d4ad83 [FIX] Fix Alibi implementation in PagedAttention kernel (#945)
* [FIX] Fix Alibi implementation in PagedAttention kernel

* Fix test_attention

* Fix

---------

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Oliver-ss <yuansongwx@outlook.com>
2023-09-07 15:53:14 -07:00
Zhuohan Li
c957c741d9 Enable safetensors loading for all models (#974) 2023-09-07 15:49:52 -07:00
Antoni Baum
c07ece5ca4 Make AsyncLLMEngine more robust & fix batched abort (#969)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Avnish Narayan <38871737+avnishn@users.noreply.github.com>
2023-09-07 13:43:45 -07:00
Woosuk Kwon
7a9c20c715 Bum up transformers version (#976) 2023-09-07 13:15:53 -07:00
Antoni Baum
005ba458b5 Set torch default dtype in a context manager (#971)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
2023-09-07 15:39:37 +09:00
Woosuk Kwon
320a622ec4 [BugFix] Implement RoPE for GPT-J (#941) 2023-09-06 11:54:33 +09:00
Antoni Baum
c9927c1a6a Use queue for finished requests (#957) 2023-09-05 19:27:23 -07:00
Woosuk Kwon
fbd80ad409 Clean up kernel unit tests (#938) 2023-09-05 16:57:38 -07:00
Wen Sun
22379d5513 fix: typo (#948) 2023-09-04 23:22:30 -07:00
Antoni Baum
1696725879 Initialize AsyncLLMEngine bg loop correctly (#943) 2023-09-04 17:41:22 -07:00
Zhuohan Li
002800f081 Align vLLM's beam search implementation with HF generate (#857) 2023-09-04 17:29:42 -07:00
Nelson Liu
e15932bb60 Only emit warning about internal tokenizer if it isn't being used (#939) 2023-09-05 00:50:55 +09:00
Antoni Baum
ce741ba3e4 Refactor AsyncLLMEngine (#880) 2023-09-03 21:43:43 -07:00
Woosuk Kwon
bf87484efa [BugFix] Fix NaN errors in paged attention kernel (#936) 2023-09-04 09:20:06 +09:00
Woosuk Kwon
8ce9c50d40 Avoid compiling kernels for double data type (#933) 2023-09-02 14:59:47 +09:00
Woosuk Kwon
32b6816e55 Add tests for models (#922) 2023-09-01 11:19:43 +09:00
Zhuohan Li
c128d69856 Fix README.md Link (#927) 2023-08-31 17:18:34 -07:00
Woosuk Kwon
55b28b1eee [Docs] Minor fixes in supported models (#920)
* Minor fix in supported models

* Add another small fix for Aquila model

---------

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-08-31 16:28:39 -07:00
Dong-Yong Lee
e11222333f fix: bug fix when penalties are negative (#913)
Co-authored-by: dongyong-lee <dongyong.lee@navercorp.com>
2023-09-01 00:37:17 +09:00
Aman Gupta Karmani
28873a2799 Improve _prune_hidden_states micro-benchmark (#707) 2023-08-31 13:28:43 +09:00
Zhuohan Li
0080d8329d Add acknowledgement to a16z grant 2023-08-30 02:26:47 -07:00
JFDuan
0d93f15694 Accelerate LLaMA model loading (#234) 2023-08-30 01:00:13 -07:00
lplcor
becd7a56f1 Enable request body OpenAPI spec for OpenAI endpoints (#865) 2023-08-29 21:54:08 -07:00
Aman Gupta Karmani
75471386de use flash-attn via xformers (#877) 2023-08-29 21:52:13 -07:00
Zhuohan Li
d2b2eed67c [Fix] Fix a condition for ignored sequences (#867) 2023-08-27 23:00:56 -07:00
Antoni Baum
4b6f069b6f Add support for CodeLlama (#854) 2023-08-25 12:44:07 -07:00
Woosuk Kwon
791d79de32 Bump up the version to v0.1.4 (#846)
2023-08-25 12:28:00 +09:00
Woosuk Kwon
94d2f59895 Set replacement=True in torch.multinomial (#858) 2023-08-25 12:22:01 +09:00
wenjun93
75c0ca9d43 Clean up code (#844) 2023-08-23 16:44:15 -07:00
Woosuk Kwon
2a4ec90854 Fix for breaking changes in xformers 0.0.21 (#834) 2023-08-23 17:44:21 +09:00
ldwang
85ebcda94d Fix typo of Aquila in README.md (#836) 2023-08-22 20:48:36 -07:00
Woosuk Kwon
d64bf1646c Implement approximate GELU kernels (#828) 2023-08-23 07:43:21 +09:00
Woosuk Kwon
a41c20435e Add compute capability 8.9 to default targets (#829) 2023-08-23 07:28:38 +09:00
Wen Sun
eedac9dba0 fix: revert code to avoid no attribute problem (#827) 2023-08-22 11:55:16 -07:00
Zhuohan Li
14f9c72bfd Update Supported Model List (#825) 2023-08-22 11:51:44 -07:00
shunxing1234
ad5f2fe34c Add support for aquila (#663)
* add aquila

Signed-off-by: ftgreat <ftgreat@163.com>

* fix some bug

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* delete pdb

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* fix bugs

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* fix bugs

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* delete whitespace

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* format

* fix order

---------

Signed-off-by: ftgreat <ftgreat@163.com>
Signed-off-by: shunxing1234 <xw747777271@gmail.com>
Co-authored-by: ftgreat <ftgreat@163.com>
2023-08-22 00:13:36 -07:00
zhaoyang-star
4f8584756d Fix mqa is false case in gpt_bigcode (#806) 2023-08-21 22:22:06 -07:00
Xudong Zhang
65fc1c3127 set default coompute capability according to cuda version (#773) 2023-08-21 16:05:44 -07:00
Daniel
c393af6cd7 [Feature | CI] Added a github action to build wheels (#746) 2023-08-21 16:59:15 +09:00
wangcx18
0c04ce3234 Fix typo in sampling_params.py (#788) 2023-08-18 10:12:46 +09:00
Xinyu Yang
73b3de79ea explicitly del state (#784) 2023-08-17 12:56:04 -07:00
Abraham-Xu
d1744376ae Align with huggingface Top K sampling (#753) 2023-08-15 16:44:33 -07:00
Ikko Eltociear Ashimine
805de738f6 Fix typo in tokenizer.py (#750)
conjuction -> conjunction
2023-08-14 22:26:36 -07:00
Uranus
1b151ed181 Fix baichuan doc style (#748) 2023-08-13 20:57:31 -07:00
WanMok
e06f504a76 Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715) 2023-08-11 12:14:34 -07:00
WRH
462ae5220a [Fix] unwantted bias in InternLM Model (#740) 2023-08-11 11:40:37 -07:00
Nicolas Basile
66c54aa9c3 Check the max prompt length for the OpenAI completions API (#472) 2023-08-08 17:43:49 -07:00
Jia Guoqing
735ecfff61 add internlm model (#528) 2023-08-08 16:35:06 -07:00
Qing
a57d13cc96 add QWen-7b (#685)
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
2023-08-08 13:50:38 -07:00
Dean Leitersdorf
79af7e96a0 [OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel (#420) 2023-08-04 10:57:29 -07:00
Wen Sun
621980bdc0 fix: incorrect bigcode attention heads num (#676) 2023-08-04 10:35:22 -07:00
Zhuohan Li
aa84c92ef6 Bump up version to 0.1.3 (#657) 2023-08-02 16:46:53 -07:00
Zhuohan Li
f7389f4763 [Doc] Add Baichuan 13B to supported models (#656) 2023-08-02 16:45:12 -07:00
Woosuk Kwon
55fe8a81ec Refactor scheduler (#658) 2023-08-02 16:42:01 -07:00
YHPeter
e8ddc08ec8 [BUG FIX] upgrade fschat version to 0.2.23 (#650)
Co-authored-by: hao.yu <hao.yu@cn-c017.server.mila.quebec>
2023-08-02 14:05:59 -07:00
Zhuohan Li
1b0bd0fe8a Add Falcon support (new) (#592) 2023-08-02 14:04:39 -07:00
Lily Liu
20044cab7a Fix log message in scheduler (#652) 2023-08-02 13:35:10 -07:00
Song
64f23c2900 fix baichuan for different position embedding for 7b and 13b models (#643) 2023-08-01 22:22:51 -07:00
Qing
d4c7755ca8 fix biachuan-7b tp (#598)
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
2023-08-01 15:41:36 -07:00
Chaofan Lin
aa39e42c5a fix doc (#622) 2023-07-31 13:11:57 -07:00
Fang li
953f28cf9a fix ModuleNotFoundError (#599)
Co-authored-by: fangli <fangli@tencent.com>
2023-07-29 20:52:41 -07:00
Xudong Zhang
c0d00f5be6 [Fix] fix import error of RayWorker (#604) (#605) 2023-07-27 23:37:40 -07:00
Zhuohan Li
58a072be15 [Fix] Add model sequence length into model config (#575) 2023-07-25 23:46:30 -07:00
Zhuohan Li
82ad323dee [Fix] Add chat completion Example and simplify dependencies (#576) 2023-07-25 23:45:48 -07:00
Zhuohan Li
df5dd3c68e Add Baichuan-7B to README (#494) 2023-07-25 15:25:12 -07:00
MoeedDar
2d867b55fa fixed tensor parallel is not defined (#564) 2023-07-25 14:16:51 -07:00
Tao Peng
d7a1c6d614 Fix paged attention testing. (#495)
Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
2023-07-24 21:01:56 -07:00
Zhuohan Li
7d5a155e4a [Fix] Fix GPTBigcoder for distributed execution (#503) 2023-07-24 18:36:33 -07:00
leegohi04517
1dde34e0f8 GPTJConfig has no attribute rotary. (#532) 2023-07-24 11:29:30 -07:00
Zhuohan Li
6fc2a38b11 Add support for LLaMA-2 (#505) 2023-07-20 11:38:27 -07:00
Antoni Baum
c487a221ee Fix bad assert in initialize_cluster if PG already exists (#526) 2023-07-19 23:17:12 -07:00
Antoni Baum
9925c17940 Ray placement group support (#397) 2023-07-19 22:49:31 -07:00
Ricardo Lu
8c4b2592fb fix: enable trust-remote-code in api server & benchmark. (#509) 2023-07-19 17:06:15 -07:00
WRH
cf21a9bd5c support trust_remote_code in benchmark (#518) 2023-07-19 17:02:40 -07:00
Massimiliano Pronesti
16c3e295a8 fix(ray_utils): ignore re-init error (#465) 2023-07-19 17:01:19 -07:00
Song
bda41c70dd hotfix attn alibi wo head mapping (#496)
Co-authored-by: oliveryuan <oliveryuan@basemind.com>
2023-07-18 11:31:48 -07:00
Lily Liu
453bafb96f Merge pull request #498 from MoeedDar/main
Fixed old name reference for max_seq_len
2023-07-18 09:22:56 -07:00
MoeedDar
328d231c17 Fixed old name reference for max_seq_len 2023-07-18 16:47:59 +01:00
Lily Liu
b4b195b360 fix max seq len (#489) 2023-07-17 23:20:20 -07:00
codethazine
20b0d88d16 Add support for baichuan (#365) 2023-07-17 13:50:55 -07:00
Zhuohan Li
2bdea7ac11 [Fix] Fix the condition of max_seq_len (#477) 2023-07-17 00:33:48 -04:00
Zhanghao Wu
58df2883cb [Doc] Add doc for running vLLM on the cloud (#426)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-07-16 13:37:14 -07:00
Zhangir Azerbayev
6d7d95a70a Offload port selection to OS (#467) 2023-07-15 23:11:02 -07:00
Zhuohan Li
96853af5a8 Optimize MQA Kernel (#452) 2023-07-14 20:06:40 -04:00
Wen Sun
dbed69058c Fix the KeyError when loading bloom-based models (#441) 2023-07-13 21:58:09 -07:00
panda
7b6ae94059 add vocab padding for LLama(Support WizardLM) (#411) 2023-07-13 23:56:22 -04:00
xcnick
c6dfc3cdbe Fix handling of special tokens in decoding. (#418) 2023-07-12 11:14:56 -04:00
Keming
51be365143 fix: freeze pydantic to v1 (#429) 2023-07-12 11:10:55 -04:00
Andre Slavescu
c894836108 [Model] Add support for GPT-J (#226)
Co-authored-by: woWoosuk Kwon <woosuk.kwon@berkeley.edu>
2023-07-08 17:55:16 -07:00
Fazlul Shahriar
75beba29b5 Don't try to load training_args.bin (#373) 2023-07-08 15:26:28 -07:00
Woosuk Kwon
ddfdf470ae Add trust_remote_code arg to get_config (#405) 2023-07-08 15:24:17 -07:00
Woosuk Kwon
b6fbb9a565 Sort the outputs before return (#402) 2023-07-08 14:48:18 -07:00
Lily Liu
2179e4f4c5 avoid python list copy in sequence initialization (#401) 2023-07-08 12:42:08 -07:00
codethazine
a945fcc2ae Add trust-remote-code flag to handle remote tokenizers (#364) 2023-07-07 11:04:58 -07:00
Nicolas Frenay
be54f8e5c4 [Fix] Change /generate response-type to json for non-streaming (#374) 2023-07-06 18:15:17 -07:00
Ricardo Lu
b396cb4998 fix: only response [DONE] once when streaming response. (#378) 2023-07-06 18:08:40 -07:00
Woosuk Kwon
1c395b4eaa Bump up the version (#300) 2023-07-04 21:41:53 -07:00
akxxsb
3d64cf019e [Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (#357) 2023-07-04 21:39:59 -07:00
Zhuohan Li
98fe8cb542 [Server] Add option to specify chat template for chat endpoint (#345) 2023-07-03 23:01:56 -07:00
Woosuk Kwon
ffa6d2f9f9 [Docs] Fix typo (#346) 2023-07-03 16:51:47 -07:00
Woosuk Kwon
404422f42e [Model] Add support for MPT (#334) 2023-07-03 16:47:53 -07:00
coolcloudcol
7717d0838b Fix an endless loop issue when engine_step throws a RuntimeError (#339) 2023-07-03 15:22:28 -07:00
Zhuohan Li
42e0c1df78 [Quality] Add CI for formatting (#343) 2023-07-03 14:50:56 -07:00
Woosuk Kwon
e41f06702c Add support for BLOOM (#331) 2023-07-03 13:12:35 -07:00
Zhuohan Li
d6fa1be3a8 [Quality] Add code formatter and linter (#326) 2023-07-03 11:31:55 -07:00
Zhuohan Li
0ffded812a [Fix] Better error message for batched prompts (#342) 2023-07-03 09:27:31 -07:00
Michele Catalano
0bd2a573a5 Allow send list of str for the Prompt on openai demo endpoint /v1/completions (#323)
* allow str or List[str] for prompt

* Update vllm/entrypoints/openai/api_server.py

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

---------

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-07-03 09:17:50 -07:00
Ricardo Lu
49b26e2cec feat: add ChatCompletion endpoint in OpenAI demo server. (#330) 2023-07-02 22:54:33 -07:00
Lily Liu
dafd924c1f Raise error for long prompt (#273) 2023-06-30 18:48:49 -07:00
Zhuohan Li
598dc4b79a [Fix] Weight loading for GPTBigCode (#313) 2023-06-29 22:14:17 -07:00
Zhuohan Li
85de093472 [Fix] Do not pin memory when in WSL (#312) 2023-06-29 15:00:21 -07:00
Zhanghao Wu
f72297562f Add news for the vllm+skypilot example (#314) 2023-06-29 12:32:37 -07:00
Bayang
9d27b09d12 Update README.md (#306) 2023-06-29 06:52:15 -07:00
Woosuk Kwon
998d9d1509 [Tokenizer] Add tokenizer mode (#298) 2023-06-28 14:19:22 -07:00
Lily Liu
425040d4c1 remove floats == 0 comparison (#285) 2023-06-28 14:11:51 -07:00
Woosuk Kwon
4338cc4750 [Tokenizer] Add an option to specify tokenizer (#284) 2023-06-28 09:46:58 -07:00
Jishnu Ray Chowdhury
bdd6b4c8bc Add LLM.set_tokenizer (#283) 2023-06-28 00:28:29 -07:00
Cody Yu
2b7d3aca2e Update setup.py (#282)
Co-authored-by: neubig <neubig@gmail.com>
2023-06-27 14:34:23 -07:00
twaka
4026a049d3 expand coverage of gpt2 model loading (#271) 2023-06-27 06:27:41 -07:00
Zhuohan Li
43710e8d09 [Fix] Fix default port number in benchmark scripts (#265) 2023-06-26 13:15:35 -07:00
Woosuk Kwon
526df28fb2 [BugFix] Fix a bug in counting running sequences (#266) 2023-06-26 13:09:02 -07:00
Zhuohan Li
2cf1a333b6 [Doc] Documentation for distributed inference (#261) 2023-06-26 11:34:23 -07:00
Zhuohan Li
0b7db411b5 [Bug] Fix the OOM condition for CPU cache (#260) 2023-06-26 11:16:13 -07:00
BasicCoder
471a7a4566 Compatible with Decapoda Research llama hf version (#251) 2023-06-26 09:23:57 -07:00
Lianmin Zheng
6214dd6ce9 Update README.md (#236) 2023-06-25 16:58:06 -07:00
metacryptom
0603379863 fix wrong using getattr to get dict value (#232) 2023-06-24 22:00:24 -07:00
Woosuk Kwon
665c48963b [Docs] Add GPTBigCode to supported models (#213) 2023-06-22 15:05:11 -07:00
Michael Feil
298695b766 GPTBigCode (StarCoder, SantaCoder Support) (#209) 2023-06-23 01:49:27 +08:00
Zhuohan Li
83658c8ace Bump up version to 0.1.1 (#204) 2023-06-22 15:33:32 +08:00
Zhuohan Li
1d24ccb96c [Fix] Better error message when there is OOM during cache initialization (#203) 2023-06-22 15:30:06 +08:00
Woosuk Kwon
14f0b39cda [Bugfix] Fix a bug in RequestOutput.finished (#202) 2023-06-22 00:17:24 -07:00
Zhuohan Li
2e0d314384 fix-ray (#193) 2023-06-22 00:21:41 +08:00
112 changed files with 8969 additions and 1883 deletions

.github/workflows/publish.yml (new file, 101 lines)

@@ -0,0 +1,101 @@
# This workflow will upload a Python Package to Release asset
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions

name: Create Release

on:
  push:
    tags:
      - v*

# Needed to create release and upload assets
permissions:
  contents: write

jobs:
  release:
    # Retrieve tag and create release
    name: Create Release
    runs-on: ubuntu-latest
    outputs:
      upload_url: ${{ steps.create_release.outputs.upload_url }}
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Extract branch info
        shell: bash
        run: |
          echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV

      - name: Create Release
        id: create_release
        uses: "actions/github-script@v6"
        env:
          RELEASE_TAG: ${{ env.release_tag }}
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
          script: |
            const script = require('.github/workflows/scripts/create_release.js')
            await script(github, context, core)

  wheel:
    name: Build Wheel
    runs-on: ${{ matrix.os }}
    needs: release

    strategy:
      fail-fast: false
      matrix:
        os: ['ubuntu-20.04']
        python-version: ['3.8', '3.9', '3.10', '3.11']
        cuda-version: ['11.8'] # Github runner can't build anything older than 11.8

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Linux Env
        if: ${{ runner.os == 'Linux' }}
        run: |
          bash -x .github/workflows/scripts/env.sh

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install CUDA ${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}

      - name: Install PyTorch-cu${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}

      - name: Build wheel
        shell: bash
        run: |
          bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename)
          asset_name=${wheel_name//"linux"/"manylinux1"}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
          echo "asset_name=${asset_name}" >> $GITHUB_ENV

      - name: Upload Release Asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ needs.release.outputs.upload_url }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_name: ${{ env.asset_name }}
          asset_content_type: application/*

      # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
      # - name: Publish package
      #   uses: pypa/gh-action-pypi-publish@release/v1.8
      #   with:
      #     repository-url: https://test.pypi.org/legacy/
      #     password: ${{ secrets.PYPI_API_TOKEN }}
      #     skip-existing: true
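As a concrete example of the tag handling above: pushing the tag v0.1.7 sets GITHUB_REF to refs/tags/v0.1.7, so the ${GITHUB_REF#refs/*/} expansion strips the refs/tags/ prefix and the release (and the wheel assets uploaded to it) is created under the name v0.1.7.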

.github/workflows/pylint.yml (new file, 31 lines)

@@ -0,0 +1,31 @@
name: pylint

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  pylint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint==2.8.2
    - name: Analysing the code with pylint
      run: |
        pylint vllm

.github/workflows/scripts/build.sh (new file, 15 lines)

@@ -0,0 +1,15 @@
#!/bin/bash
python_executable=python$1
cuda_home=/usr/local/cuda-$2
# Update paths
PATH=${cuda_home}/bin:$PATH
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist

.github/workflows/scripts/create_release.js (new file, 20 lines)

@@ -0,0 +1,20 @@
// Uses Github's API to create the release and wait for result.
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.

module.exports = async (github, context, core) => {
  try {
    const response = await github.rest.repos.createRelease({
      draft: false,
      generate_release_notes: true,
      name: process.env.RELEASE_TAG,
      owner: context.repo.owner,
      prerelease: false,
      repo: context.repo.repo,
      tag_name: process.env.RELEASE_TAG,
    });

    core.setOutput('upload_url', response.data.upload_url);
  } catch (error) {
    core.setFailed(error.message);
  }
}

.github/workflows/scripts/cuda-install.sh (new file, 18 lines)

@@ -0,0 +1,18 @@
#!/bin/bash
# Replace '.' with '-' ex: 11.8 -> 11-8
cuda_version=$(echo $1 | tr "." "-")
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
OS=$(echo $2 | tr -d ".\-")
# Installs CUDA
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
rm cuda-keyring_1.1-1_all.deb
sudo apt -qq update
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
sudo apt clean
# Test nvcc
PATH=/usr/local/cuda-$1/bin:${PATH}
nvcc --version

.github/workflows/scripts/env.sh (new file, 56 lines)

@@ -0,0 +1,56 @@
#!/bin/bash
# This file installs common linux environment tools
export LANG C.UTF-8
# python_version=$1
sudo apt-get update && \
sudo apt-get install -y --no-install-recommends \
software-properties-common \
sudo apt-get install -y --no-install-recommends \
build-essential \
apt-utils \
ca-certificates \
wget \
git \
vim \
libssl-dev \
curl \
unzip \
unrar \
cmake \
net-tools \
sudo \
autotools-dev \
rsync \
jq \
openssh-server \
tmux \
screen \
htop \
pdsh \
openssh-client \
lshw \
dmidecode \
util-linux \
automake \
autoconf \
libtool \
net-tools \
pciutils \
libpci-dev \
libaio-dev \
libcap2 \
libtinfo5 \
fakeroot \
devscripts \
debhelper \
nfs-common
# Remove github bloat files to free up disk space
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo rm -rf "/usr/share/dotnet"

.github/workflows/scripts/pytorch-install.sh (new file, 14 lines)

@@ -0,0 +1,14 @@
#!/bin/bash
python_executable=python$1
cuda_version=$2
# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
# Print version information
$python_executable --version
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"

.github/workflows/yapf.yml (new file, 31 lines)

@@ -0,0 +1,31 @@
name: yapf

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install yapf==0.32.0
        pip install toml==0.10.2
    - name: Running yapf
      run: |
        yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'

.gitignore (3 lines added)

@@ -170,3 +170,6 @@ cython_debug/
 # Python pickle files
 *.pkl
+
+# Sphinx documentation
+_build/

.pylintrc (new file, 434 lines)

@@ -0,0 +1,434 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
logging-fstring-interpolation, # added by vLLM
logging-not-lazy, # added by vLLM
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-class-docstring, # TODO (vLLM): enable
missing-function-docstring,
missing-module-docstring, # TODO (vLLM): enable
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
unspecified-encoding,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externaly-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException

CONTRIBUTING.md

@@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possi
 In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
+We include a formatting script [`format.sh`](./format.sh) to format the code.

 ### Pull Requests

 When submitting a pull request:

 1. Make sure your code has been rebased on top of the latest commit on the main branch.
-2. Include a detailed description of the changes in the pull request.
+2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+3. Include a detailed description of the changes in the pull request.
 Explain why you made the changes you did.
 If your pull request fixes an open issue, please include a reference to it in the description.

README.md

@@ -17,8 +17,10 @@ Easy, fast, and cheap LLM serving for everyone
 ---

 *Latest News* 🔥
+- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
-- [2023/06] We officially released vLLM! vLLM has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid April. Check out our [blog post](https://vllm.ai).
+- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
+- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
+- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

 ---
@@ -28,7 +30,7 @@ vLLM is fast with:
 - State-of-the-art serving throughput
 - Efficient management of attention key and value memory with **PagedAttention**
-- Dynamic batching of incoming requests
+- Continuous batching of incoming requests
 - Optimized CUDA kernels

 vLLM is flexible and easy to use with:
@@ -41,10 +43,19 @@ vLLM is flexible and easy to use with:
 vLLM seamlessly supports many Huggingface models, including the following architectures:

+- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
+- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
-- GPTNeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
+- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
-- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
+- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
+- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)

 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
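For orientation, the quickstart this README points to boils down to installing the wheel and calling the Python API. The snippet below is a minimal offline-inference sketch against the vLLM API of this release line; the model name and sampling values are illustrative and not taken from this diff.

# Minimal sketch: `pip install vllm`, then run offline inference.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any architecture from the list above
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)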

benchmarks/benchmark_latency.py

@@ -17,9 +17,11 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
+        tokenizer=args.tokenizer,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
         max_num_batched_tokens=args.batch_size * args.input_len,
+        trust_remote_code=args.trust_remote_code,
     )

     sampling_params = SamplingParams(
@@ -63,6 +65,7 @@ if __name__ == '__main__':
         description='Benchmark the latency of processing a single batch of '
                     'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
@@ -72,5 +75,7 @@ if __name__ == '__main__':
     parser.add_argument('--use-beam-search', action='store_true')
     parser.add_argument('--num-iters', type=int, default=3,
                         help='Number of iterations to run.')
+    parser.add_argument('--trust-remote-code', action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
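With these additions, a checkpoint that ships custom modeling code (for example the Qwen and InternLM models added elsewhere in this compare) can be benchmarked by passing --trust-remote-code, and --tokenizer allows the tokenizer to be loaded from a different repository than the model weights; both values are forwarded straight into the LLM(...) constructor.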

benchmarks/benchmark_serving.py

@@ -24,20 +24,13 @@ from typing import AsyncGenerator, List, Tuple
 import aiohttp
 import numpy as np
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase
+
+from vllm.transformers_utils.tokenizer import get_tokenizer

 # (prompt len, output len, latency)
 REQUEST_LATENCY: List[Tuple[int, int, float]] = []

-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-    return AutoTokenizer.from_pretrained(model_name)

 def sample_requests(
     dataset_path: str,
     num_requests: int,
@@ -184,7 +177,7 @@ def main(args: argparse.Namespace):
     np.random.seed(args.seed)

     api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer)
+    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     benchmark_start_time = time.time()
@@ -217,7 +210,7 @@ if __name__ == "__main__":
     parser.add_argument("--backend", type=str, default="vllm",
                         choices=["vllm", "tgi"])
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--tokenizer", type=str, required=True,
@@ -234,5 +227,7 @@
                              "Otherwise, we use Poisson process to synthesize "
                              "the request arrival times.")
     parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument('--trust-remote-code', action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)

benchmarks/benchmark_throughput.py

@@ -6,23 +6,11 @@ import time
 from typing import List, Tuple

 import torch
-from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
-                          PreTrainedTokenizerBase)
+from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from tqdm import tqdm

 from vllm import LLM, SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer

-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        # To enable padding in the HF backend.
-        tokenizer.pad_token = tokenizer.eos_token
-        return tokenizer
-    return AutoTokenizer.from_pretrained(model_name)

 def sample_requests(
@@ -74,15 +62,19 @@ def sample_requests(
 def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
+    tokenizer: str,
     tensor_parallel_size: int,
     seed: int,
     n: int,
     use_beam_search: bool,
+    trust_remote_code: bool,
 ) -> float:
     llm = LLM(
         model=model,
+        tokenizer=tokenizer,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
+        trust_remote_code=trust_remote_code,
     )

     # Add the requests to the engine.
@@ -116,11 +108,14 @@ def run_hf(
     n: int,
     use_beam_search: bool,
     max_batch_size: int,
+    trust_remote_code: bool,
 ) -> float:
     assert not use_beam_search
-    tokenizer = get_tokenizer(model)
-    llm = AutoModelForCausalLM.from_pretrained(
-        model, torch_dtype=torch.float16)
+    llm = AutoModelForCausalLM.from_pretrained(model,
+        torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
+    if llm.config.model_type == "llama":
+        # To enable padding in the HF backend.
+        tokenizer.pad_token = tokenizer.eos_token
     llm = llm.cuda()

     pbar = tqdm(total=len(requests))
@@ -170,17 +165,18 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)

     # Sample the requests.
-    tokenizer = get_tokenizer(args.model)
+    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     if args.backend == "vllm":
         elapsed_time = run_vllm(
-            requests, args.model, args.tensor_parallel_size, args.seed, args.n,
-            args.use_beam_search)
+            requests, args.model, args.tokenizer, args.tensor_parallel_size,
+            args.seed, args.n, args.use_beam_search, args.trust_remote_code)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
-        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
-                              args.use_beam_search, args.hf_max_batch_size)
+        elapsed_time = run_hf(
+            requests, args.model, tokenizer, args.n, args.use_beam_search,
+            args.hf_max_batch_size, args.trust_remote_code)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")

     total_num_tokens = sum(
@@ -198,6 +194,7 @@ if __name__ == "__main__":
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
+    parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n", type=int, default=1,
                         help="Number of generated sequences per prompt.")
@@ -207,12 +204,18 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--hf-max-batch-size", type=int, default=None,
                         help="Maximum batch size for HF backend.")
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()

     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
             raise ValueError("HF max batch size is only for HF backend.")
     elif args.backend == "hf":
         if args.hf_max_batch_size is None:
             raise ValueError("HF max batch size is required for HF backend.")
+    if args.tokenizer is None:
+        args.tokenizer = args.model

     main(args)

benchmarks/launch_tgi_server.sh

@@ -1,6 +1,6 @@
 #!/bin/bash

-PORT=8001
+PORT=8000
 MODEL=$1
 TOKENS=$2

csrc/activation.cpp

@@ -4,9 +4,25 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);

+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  m.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  m.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
 }

csrc/activation_kernels.cu

@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>

+#include "dispatch_utils.h"
+
 namespace vllm {

 template<typename T>
@@ -34,9 +36,7 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
@@ -46,3 +46,69 @@ void silu_and_mul(
         d);
     });
 }
+
+namespace vllm {
+
+// Element-wise activation kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void activation_kernel(
+  scalar_t* __restrict__ out,          // [num_tokens, d]
+  const scalar_t* __restrict__ input,  // [num_tokens, d]
+  const int d) {
+  const int token_idx = blockIdx.x;
+  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x);
+  }
+}
+
+} // namespace vllm
+
+// Launch element-wise activation kernel.
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                              \
+  int num_tokens = input.size(0);                                                     \
+  int d = input.size(1);                                                              \
+  dim3 grid(num_tokens);                                                              \
+  dim3 block(std::min(d, 1024));                                                      \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                       \
+  VLLM_DISPATCH_FLOATING_TYPES(                                                       \
+    input.scalar_type(),                                                              \
+    "activation_kernel",                                                              \
+    [&] {                                                                             \
+      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(),                                                     \
+        input.data_ptr<scalar_t>(),                                                   \
+        d);                                                                           \
+    });
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T gelu_new_kernel(const T& x) {
+  const float x3 = (float) (x * x * x);
+  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
+  const float f = (float) x;
+  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+} // namespace vllm
+
+void gelu_new(
+  torch::Tensor& out,    // [num_tokens, d]
+  torch::Tensor& input)  // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
+}
+
+void gelu_fast(
+  torch::Tensor& out,    // [num_tokens, d]
+  torch::Tensor& input)  // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
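For reference, both device functions above compute the same tanh approximation of GELU; the constant 0.79788456 is \sqrt{2/\pi}, and gelu_fast_kernel merely factors the cubic term as x(1 + 0.044715 x^2):

\mathrm{GELU}(x) \approx \tfrac{1}{2}\, x \left(1 + \tanh\!\bigl(\sqrt{2/\pi}\,(x + 0.044715\, x^{3})\bigr)\right)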

csrc/attention.cpp

@@ -1,15 +1,18 @@
 #include <torch/extension.h>
+#include <c10/util/Optional.h>

 void single_query_cached_kv_attention(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int block_size,
-  int max_context_len);
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(

View File

@@ -74,14 +74,20 @@ template<
__global__ void single_query_cached_kv_attention_kernel( __global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* __restrict__ head_mapping, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const int q_stride) { const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
@@ -90,7 +96,9 @@ __global__ void single_query_cached_kv_attention_kernel(
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int kv_head_idx = head_mapping[head_idx];
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group // The vector size is configured in such a way that the threads in a thread group
@@ -114,12 +122,13 @@ __global__ void single_query_cached_kv_attention_kernel(
// th vectors of the query, and so on. // th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
@@ -156,8 +165,8 @@ __global__ void single_query_cached_kv_attention_kernel(
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ head_idx * HEAD_SIZE * BLOCK_SIZE + kv_head_idx * kv_head_stride
+ physical_block_offset * x; + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
@@ -167,12 +176,14 @@ __global__ void single_query_cached_kv_attention_kernel(
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
const bool mask = token_idx >= context_len; // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk; logits[token_idx] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -235,6 +246,8 @@ __global__ void single_query_cached_kv_attention_kernel(
accs[i] = 0.f; accs[i] = 0.f;
} }
scalar_t zero_value;
zero(zero_value);
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx]; const int physical_block_number = block_table[block_idx];
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
@@ -242,14 +255,24 @@ __global__ void single_query_cached_kv_attention_kernel(
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx)); from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ head_idx * HEAD_SIZE * BLOCK_SIZE; + kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
if (block_idx == num_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j <= V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec); accs[i] += dot(logits_vec, v_vec);
} }
} }
@@ -324,11 +347,15 @@ __global__ void single_query_cached_kv_attention_kernel(
query_ptr, \ query_ptr, \
key_cache_ptr, \ key_cache_ptr, \
value_cache_ptr, \ value_cache_ptr, \
head_mapping_ptr, \
scale, \ scale, \
block_tables_ptr, \ block_tables_ptr, \
context_lens_ptr, \ context_lens_ptr, \
max_num_blocks_per_seq, \ max_num_blocks_per_seq, \
query_stride); alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template< template<
@@ -340,23 +367,33 @@ void single_query_cached_kv_attention_launcher(
torch::Tensor& query, torch::Tensor& query,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& context_lens,
int max_context_len) { int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1); int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0); int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* context_lens_ptr = context_lens.data_ptr<int>();
@@ -371,7 +408,7 @@ void single_query_cached_kv_attention_launcher(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) { switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
// 32, 160, 192, 256. // 32, 160, 192.
// case 32: // case 32:
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
// break; // break;
@@ -384,6 +421,9 @@ void single_query_cached_kv_attention_launcher(
case 96: case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break; break;
case 112:
LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
break;
case 128: case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break; break;
@@ -393,9 +433,9 @@ void single_query_cached_kv_attention_launcher(
// case 192: // case 192:
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
// break; // break;
// case 256: case 256:
// LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
// break; break;
default: default:
TORCH_CHECK(false, "Unsupported head size: ", head_size); TORCH_CHECK(false, "Unsupported head size: ", head_size);
break; break;
@@ -408,10 +448,12 @@ void single_query_cached_kv_attention_launcher(
query, \ query, \
key_cache, \ key_cache, \
value_cache, \ value_cache, \
head_mapping, \
scale, \ scale, \
block_tables, \ block_tables, \
context_lens, \ context_lens, \
max_context_len); max_context_len, \
alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
@@ -454,11 +496,13 @@ void single_query_cached_kv_attention(
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& head_mapping, // [num_heads]
float scale, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& context_lens, // [num_seqs]
int block_size, int block_size,
int max_context_len) { int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
if (query.dtype() == at::ScalarType::Float) { if (query.dtype() == at::ScalarType::Float) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) { } else if (query.dtype() == at::ScalarType::Half) {
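
To make the new kernel arguments concrete, here is a minimal NumPy sketch (dense tensors instead of the paged, block-table cache; illustrative names only) of what the kernel now computes per query token: head_mapping folds query heads onto shared KV heads for MQA/GQA, an optional ALiBi bias of alibi_slope * (token_idx - context_len + 1) is added to each logit, and everything at or beyond context_len is masked out:

import numpy as np

def single_query_attention(q, k_cache, v_cache, head_mapping, scale,
                           context_len, alibi_slopes=None):
    # q: [num_heads, head_size]
    # k_cache, v_cache: [num_kv_heads, max_context_len, head_size]
    num_heads, _ = q.shape
    out = np.zeros_like(q)
    for h in range(num_heads):
        kv_h = head_mapping[h]               # map query head -> KV head (MQA/GQA)
        k = k_cache[kv_h, :context_len]      # only tokens inside the context
        v = v_cache[kv_h, :context_len]
        logits = scale * (k @ q[h])          # [context_len]
        if alibi_slopes is not None:
            positions = np.arange(context_len)
            logits += alibi_slopes[h] * (positions - context_len + 1)
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        out[h] = probs @ v
    return out

The real kernel reaches the same result while walking the KV cache through block_tables with kv_block_stride / kv_head_stride, and the zero_value loop above exists only to keep out-of-context cache slots from injecting NaNs into the reduction.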

View File

@@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#endif #endif
} }
// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}
} // namespace vllm } // namespace vllm

View File

@@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
return sum(c); return sum(c);
} }
// Zero-out a vector.
inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}
// From float32 to float16. // From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) { inline __device__ void from_float(uint16_t& dst, float src) {
dst = float_to_half(src); dst = float_to_half(src);
@@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
return tmp; return tmp;
} }
// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm

View File

@@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
return u; return u;
} }
// Zero-out a variable.
inline __device__ void zero(float& dst) {
dst = 0.f;
}
} // namespace vllm } // namespace vllm

View File

@@ -1,6 +1,8 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <map> #include <map>
@@ -125,9 +127,7 @@ void copy_blocks(
dim3 grid(num_layers, num_pairs); dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block)); dim3 block(std::min(1024, numel_per_block));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2( VLLM_DISPATCH_FLOATING_TYPES(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -202,9 +202,7 @@ void reshape_and_cache(
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512)); dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2( VLLM_DISPATCH_FLOATING_TYPES(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key.scalar_type(), key.scalar_type(),
"reshape_and_cache_kernel", "reshape_and_cache_kernel",
[&] { [&] {
@@ -364,9 +362,7 @@ void gather_cached_kv(
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512)); dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2( VLLM_DISPATCH_FLOATING_TYPES(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key.scalar_type(), key.scalar_type(),
"gather_cached_kv_kernel_optimized", "gather_cached_kv_kernel_optimized",
[&] { [&] {

csrc/dispatch_utils.h (new file, +14 lines)

View File

@@ -0,0 +1,14 @@
/*
* Adapted from
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
*/
#include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

View File

@@ -1,6 +1,7 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
#include "reduction_utils.cuh" #include "reduction_utils.cuh"
namespace vllm { namespace vllm {
@@ -46,9 +47,7 @@ void rms_norm(
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2( VLLM_DISPATCH_FLOATING_TYPES(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(), input.scalar_type(),
"rms_norm_kernel", "rms_norm_kernel",
[&] { [&] {

View File

@@ -1,15 +1,16 @@
#include <torch/extension.h> #include <torch/extension.h>
void rotary_embedding_neox( void rotary_embedding(
torch::Tensor& positions, torch::Tensor& positions,
torch::Tensor& query, torch::Tensor& query,
torch::Tensor& key, torch::Tensor& key,
int head_size, int head_size,
torch::Tensor& cos_sin_cache); torch::Tensor& cos_sin_cache,
bool is_neox);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"rotary_embedding_neox", "rotary_embedding",
&rotary_embedding_neox, &rotary_embedding,
"Apply GPT-NeoX style rotary embedding to query and key"); "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
} }

View File

@@ -1,17 +1,51 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
namespace vllm { namespace vllm {
template<typename scalar_t> template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_neox_kernel( inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = __ldg(cos_ptr + x_index);
sin = __ldg(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = __ldg(cos_ptr + x_index / 2);
sin = __ldg(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [num_tokens] const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int rot_dim,
const int stride, const int query_stride,
const int key_stride,
const int num_heads, const int num_heads,
const int num_kv_heads,
const int head_size) { const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
@@ -19,65 +53,75 @@ __global__ void rotary_embedding_neox_kernel(
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const int n = num_heads * embed_dim; const scalar_t* cos_ptr = cache_ptr;
for (int i = threadIdx.x; i < n; i += blockDim.x) { const scalar_t* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int token_head = token_idx * stride + head_idx * head_size; const int token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
const int x_index = rot_offset; apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
const int y_index = embed_dim + rot_offset; sin_ptr, rot_offset, embed_dim);
}
const int out_x = token_idx * stride + head_idx * head_size + x_index; const int nk = num_kv_heads * embed_dim;
const int out_y = token_idx * stride + head_idx * head_size + y_index; for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const scalar_t cos = __ldg(cache_ptr + x_index); const int token_head = token_idx * key_stride + head_idx * head_size;
const scalar_t sin = __ldg(cache_ptr + y_index); const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
const scalar_t q_x = query[token_head + x_index]; sin_ptr, rot_offset, embed_dim);
const scalar_t q_y = query[token_head + y_index];
query[out_x] = q_x * cos - q_y * sin;
query[out_y] = q_y * cos + q_x * sin;
const scalar_t k_x = key[token_head + x_index];
const scalar_t k_y = key[token_head + y_index];
key[out_x] = k_x * cos - k_y * sin;
key[out_y] = k_y * cos + k_x * sin;
} }
} }
} // namespace vllm } // namespace vllm
void rotary_embedding_neox( void rotary_embedding(
torch::Tensor& positions, // [num_tokens] torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size] torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_heads * head_size] torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
int head_size, int head_size,
torch::Tensor& cos_sin_cache) // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
{ bool is_neox) {
int num_tokens = query.size(0); int num_tokens = query.size(0);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size; int num_heads = query.size(1) / head_size;
int stride = query.stride(0); int num_kv_heads = key.size(1) / head_size;
TORCH_CHECK(stride == key.stride(0)); int query_stride = query.stride(0);
int key_stride = key.stride(0);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2( VLLM_DISPATCH_FLOATING_TYPES(
at::ScalarType::Half,
at::ScalarType::BFloat16,
query.scalar_type(), query.scalar_type(),
"rotary_embedding_neox", "rotary_embedding",
[&] { [&] {
vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>( if (is_neox) {
positions.data_ptr<int64_t>(), vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(),
key.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(),
rot_dim, cos_sin_cache.data_ptr<scalar_t>(),
stride, rot_dim,
num_heads, query_stride,
head_size); key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
}); });
} }
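
The refactor above boils down to one pairing rule per style. A small NumPy sketch of apply_rotary_embedding for a single head (illustrative names; cos and sin are the two halves of a cos_sin_cache row, each of length rot_dim // 2):

import numpy as np

def apply_rotary(head_vec, cos, sin, is_neox):
    # head_vec: [head_size]; cos, sin: [embed_dim] with embed_dim = rot_dim // 2
    out = head_vec.copy()
    embed_dim = cos.shape[0]
    for i in range(embed_dim):
        if is_neox:
            x_idx, y_idx = i, embed_dim + i   # GPT-NeoX: rotate across the two halves
        else:
            x_idx, y_idx = 2 * i, 2 * i + 1   # GPT-J: rotate interleaved pairs
        x, y = head_vec[x_idx], head_vec[y_idx]
        out[x_idx] = x * cos[i] - y * sin[i]
        out[y_idx] = y * cos[i] + x * sin[i]
    return out

The kernel applies this to every query head and, separately, to every KV head, which is why it now takes distinct query_stride / key_stride and num_heads / num_kv_heads arguments.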

View File

@@ -4,14 +4,14 @@
```bash ```bash
# Install dependencies. # Install dependencies.
pip -r requirements-docs.txt pip install -r requirements-docs.txt
# Build the docs. # Build the docs.
make clean make clean
make html make html
``` ```
## Open the docs with your brower ## Open the docs with your browser
```bash ```bash
python -m http.server -d build/html/ python -m http.server -d build/html/

View File

@@ -3,31 +3,15 @@
Installation Installation
============ ============
vLLM is a Python library that also contains some C++ and CUDA code. vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
This additional code requires compilation on the user's machine.
Requirements Requirements
------------ ------------
* OS: Linux * OS: Linux
* Python: 3.8 or higher * Python: 3.8 -- 3.11
* CUDA: 11.0 -- 11.8
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.) * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
.. note::
As of now, vLLM does not support CUDA 12.
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
.. tip::
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
.. code-block:: console
$ # Pull the Docker image with CUDA 11.8.
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
Install with pip Install with pip
---------------- ----------------
@@ -40,7 +24,7 @@ You can install vLLM using pip:
$ conda activate myenv $ conda activate myenv
$ # Install vLLM. $ # Install vLLM.
$ pip install vllm # This may take 5-10 minutes. $ pip install vllm
.. _build_from_source: .. _build_from_source:
@@ -55,3 +39,11 @@ You can also build and install vLLM from source:
$ git clone https://github.com/vllm-project/vllm.git $ git clone https://github.com/vllm-project/vllm.git
$ cd vllm $ cd vllm
$ pip install -e . # This may take 5-10 minutes. $ pip install -e . # This may take 5-10 minutes.
.. tip::
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
.. code-block:: console
$ # Pull the Docker image with CUDA 11.8.
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3

View File

@@ -29,7 +29,7 @@ vLLM is fast with:
* State-of-the-art serving throughput * State-of-the-art serving throughput
* Efficient management of attention key and value memory with **PagedAttention** * Efficient management of attention key and value memory with **PagedAttention**
* Dynamic batching of incoming requests * Continuous batching of incoming requests
* Optimized CUDA kernels * Optimized CUDA kernels
vLLM is flexible and easy to use with: vLLM is flexible and easy to use with:
@@ -40,7 +40,11 @@ vLLM is flexible and easy to use with:
* Streaming outputs * Streaming outputs
* OpenAI-compatible API server * OpenAI-compatible API server
For more information, please refer to our `blog post <https://vllm.ai>`_. For more information, check out the following:
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
Documentation Documentation
@@ -53,6 +57,13 @@ Documentation
getting_started/installation getting_started/installation
getting_started/quickstart getting_started/quickstart
.. toctree::
:maxdepth: 1
:caption: Serving
serving/distributed_serving
serving/run_on_sky
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:caption: Models :caption: Models

View File

@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+ kv_caches: List[KVCache], + kv_caches: List[KVCache],
+ input_metadata: InputMetadata, + input_metadata: InputMetadata,
+ cache_events: Optional[List[torch.cuda.Event]], + cache_events: Optional[List[torch.cuda.Event]],
+) -> Dict[int, SequenceOutputs]: +) -> SamplerOutput:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture. 4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.

View File

@@ -14,18 +14,45 @@ Alongside each architecture, we include some popular models that use it.
* - Architecture * - Architecture
- Models - Models
- Example HuggingFace Models - Example HuggingFace Models
* - :code:`AquilaForCausalLM`
- Aquila
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
* - :code:`BaiChuanForCausalLM`
- Baichuan
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
* - :code:`FalconForCausalLM`
- Falcon
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
* - :code:`GPT2LMHeadModel` * - :code:`GPT2LMHeadModel`
- GPT-2 - GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc. - :code:`gpt2`, :code:`gpt2-xl`, etc.
* - :code:`GPTBigCodeForCausalLM`
- StarCoder, SantaCoder, WizardCoder
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
* - :code:`GPTJForCausalLM`
- GPT-J
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
* - :code:`GPTNeoXForCausalLM` * - :code:`GPTNeoXForCausalLM`
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
* - :code:`InternLMForCausalLM`
- InternLM
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
* - :code:`LlamaForCausalLM` * - :code:`LlamaForCausalLM`
- LLaMA, Vicuna, Alpaca, Koala, Guanaco - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
- :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc. - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
* - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
* - :code:`OPTForCausalLM` * - :code:`OPTForCausalLM`
- OPT, OPT-IML - OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model. Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
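
As a quick illustration of the claim above, any architecture in the table is loaded through the same entry point; a minimal sketch following the repository's offline-inference example (the model name is just one entry from the table):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # swap in any model listed above
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)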

View File

@@ -0,0 +1,38 @@
.. _distributed_serving:
Distributed Inference and Serving
=================================
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:
.. code-block:: console
$ pip install ray
To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
.. code-block:: python
from vllm import LLM
llm = LLM("facebook/opt-13b", tensor_parallel_size=4)
output = llm.generate("San Francisco is a")
To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument when starting the server. For example, to run API server on 4 GPUs:
.. code-block:: console
$ python -m vllm.entrypoints.api_server \
$ --model facebook/opt-13b \
$ --tensor-parallel-size 4
To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
.. code-block:: console
$ # On head node
$ ray start --head
$ # On worker nodes
$ ray start --address=<ray-head-address>
After that, you can run inference and serving across multiple machines by launching the vLLM process on the head node and setting :code:`tensor_parallel_size` to the total number of GPUs across all machines.

View File

@@ -0,0 +1,69 @@
.. _on_cloud:
Running on clouds with SkyPilot
===============================
.. raw:: html
<p align="center">
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
</p>
vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
To install SkyPilot and set up your cloud credentials, run:
.. code-block:: console
$ pip install skypilot
$ sky check
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
.. code-block:: yaml
resources:
accelerators: A100
envs:
MODEL_NAME: decapoda-research/llama-13b-hf
TOKENIZER: hf-internal-testing/llama-tokenizer
setup: |
conda create -n vllm python=3.9 -y
conda activate vllm
git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install .
pip install gradio
run: |
conda activate vllm
echo 'Starting vllm api server...'
python -u -m vllm.entrypoints.api_server \
--model $MODEL_NAME \
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
--tokenizer $TOKENIZER 2>&1 | tee api_server.log &
echo 'Waiting for vllm api server to start...'
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
echo 'Starting gradio server...'
python vllm/examples/gradio_webserver.py
Start serving the LLaMA-13B model on an A100 GPU:
.. code-block:: console
$ sky launch serving.yaml
Check the output of the command. There will be a shareable gradio link (like the last line of the following output). Open it in your browser to use the LLaMA model for text completion.
.. code-block:: console
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
**Optional**: Serve the 65B model instead of the default 13B and use more GPUs:
.. code-block:: console
sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf

View File

@@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
print(LINE_UP, end=LINE_CLEAR, flush=True) print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(prompt: str, api_url: str, n: int = 1, def post_http_request(prompt: str,
api_url: str,
n: int = 1,
stream: bool = False) -> requests.Response: stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"} headers = {"User-Agent": "Test Client"}
pload = { pload = {
@@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]: def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"): delimiter=b"\0"):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))

View File

@@ -12,9 +12,14 @@ def http_bot(prompt):
"stream": True, "stream": True,
"max_tokens": 128, "max_tokens": 128,
} }
response = requests.post(args.model_url, headers=headers, json=pload, stream=True) response = requests.post(args.model_url,
headers=headers,
json=pload,
stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"][0] output = data["text"][0]
@@ -23,11 +28,11 @@ def http_bot(prompt):
def build_demo(): def build_demo():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown( gr.Markdown("# vLLM text completion demo\n")
"# vLLM text completion demo\n" inputbox = gr.Textbox(label="Input",
) placeholder="Enter text and press ENTER")
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER") outputbox = gr.Textbox(label="Output",
outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") placeholder="Generated result from the model")
inputbox.submit(http_bot, [inputbox], [outputbox]) inputbox.submit(http_bot, [inputbox], [outputbox])
return demo return demo
@@ -36,7 +41,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate") parser.add_argument("--model-url",
type=str,
default="http://localhost:8000/generate")
args = parser.parse_args() args = parser.parse_args()
demo = build_demo() demo = build_demo()

View File

@@ -10,19 +10,25 @@ def main(args: argparse.Namespace):
# Test the following prompts. # Test the following prompts.
test_prompts = [ test_prompts = [
("A robot may not injure a human being", SamplingParams()), ("A robot may not injure a human being",
SamplingParams(temperature=0.0)),
("To be or not to be,", ("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)), SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?", ("What is the meaning of life?",
SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)), SamplingParams(n=2,
best_of=5,
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
("It is only with the heart that one can see rightly", ("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)), SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
] ]
# Run the engine by calling `engine.step()` manually. # Run the engine by calling `engine.step()` manually.
request_id = 0 request_id = 0
while True: while True:
# To test iteration-level scheduling, we add one request at each step. # To test continuous batching, we add one request at each step.
if test_prompts: if test_prompts:
prompt, sampling_params = test_prompts.pop(0) prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params) engine.add_request(str(request_id), prompt, sampling_params)
@@ -30,7 +36,7 @@ def main(args: argparse.Namespace):
request_outputs = engine.step() request_outputs = engine.step()
for request_output in request_outputs: for request_output in request_outputs:
if request_output.finished(): if request_output.finished:
print(request_output) print(request_output)
if not (engine.has_unfinished_requests() or test_prompts): if not (engine.has_unfinished_requests() or test_prompts):

View File

@@ -1,6 +1,5 @@
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",

View File

@@ -0,0 +1,33 @@
import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
# List models API
models = openai.Model.list()
print("Models:", models)
model = models["data"][0]["id"]
# Chat completion API
chat_completion = openai.ChatCompletion.create(
model=model,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}])
print("Chat completion results:")
print(chat_completion)

View File

@@ -3,21 +3,26 @@ import openai
# Modify OpenAI's API key and API base to use vLLM's API server. # Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY" openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1" openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"
# Test list models API # List models API
models = openai.Model.list() models = openai.Model.list()
print("Models:", models) print("Models:", models)
# Test completion API model = models["data"][0]["id"]
stream = True
completion = openai.Completion.create(
model=model, prompt="A robot may not injure a human being", echo=False, n=2,
best_of=3, stream=stream, logprobs=3)
# print the completion # Completion API
stream = False
completion = openai.Completion.create(
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
stream=stream,
logprobs=3)
print("Completion results:")
if stream: if stream:
for c in completion: for c in completion:
print(c) print(c)
else: else:
print("Completion result:", completion) print(completion)

format.sh (new executable file, +108 lines)
View File

@@ -0,0 +1,108 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' 'vllm/model_executor/parallel_utils/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi

View File

@@ -1,2 +1,13 @@
mypy # formatting
yapf==0.32.0
pylint==2.8.2
# type checking
mypy==0.991
types-PyYAML
types-requests
types-setuptools
# testing
pytest pytest
pytest-forked

View File

@@ -1,11 +1,11 @@
ninja # For faster builds. ninja # For faster builds.
psutil psutil
ray ray >= 2.5.1
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy numpy
torch >= 2.0.0 torch >= 2.0.0
transformers >= 4.28.0 # Required for LLaMA. transformers >= 4.33.1 # Required for Code Llama.
xformers >= 0.0.19 xformers >= 0.0.21
fastapi fastapi
uvicorn uvicorn
pydantic # Required for OpenAI server. pydantic < 2 # Required for OpenAI server.

View File

@@ -20,10 +20,9 @@ ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
if not torch.cuda.is_available(): if CUDA_HOME is None:
raise RuntimeError( raise RuntimeError(
f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. " "Cannot find CUDA_HOME. CUDA must be available to build the package.")
"CUDA must be available in order to build the package.")
def get_nvcc_cuda_version(cuda_dir: str) -> Version: def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -48,12 +47,6 @@ for i in range(device_count):
raise RuntimeError( raise RuntimeError(
"GPUs with compute capability less than 7.0 are not supported.") "GPUs with compute capability less than 7.0 are not supported.")
compute_capabilities.add(major * 10 + minor) compute_capabilities.add(major * 10 + minor)
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
compute_capabilities = {70, 75, 80, 86, 90}
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
# Validate the NVCC CUDA version. # Validate the NVCC CUDA version.
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
@@ -61,10 +54,35 @@ if nvcc_cuda_version < Version("11.0"):
raise RuntimeError("CUDA 11.0 or higher is required to build the package.") raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
raise RuntimeError( raise RuntimeError(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.") "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
)
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
compute_capabilities.remove(89)
compute_capabilities.add(80)
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError( raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
)
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
compute_capabilities = {70, 75, 80}
if nvcc_cuda_version >= Version("11.1"):
compute_capabilities.add(86)
if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(89)
compute_capabilities.add(90)
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
NVCC_FLAGS += [
"-gencode", f"arch=compute_{capability},code=sm_{capability}"
]
# Use NVCC threads to parallelize the build. # Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"): if nvcc_cuda_version >= Version("11.2"):
@@ -77,7 +95,10 @@ ext_modules = []
cache_extension = CUDAExtension( cache_extension = CUDAExtension(
name="vllm.cache_ops", name="vllm.cache_ops",
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"], sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(cache_extension) ext_modules.append(cache_extension)
@@ -85,7 +106,10 @@ ext_modules.append(cache_extension)
attention_extension = CUDAExtension( attention_extension = CUDAExtension(
name="vllm.attention_ops", name="vllm.attention_ops",
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"], sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(attention_extension) ext_modules.append(attention_extension)
@@ -93,7 +117,10 @@ ext_modules.append(attention_extension)
positional_encoding_extension = CUDAExtension( positional_encoding_extension = CUDAExtension(
name="vllm.pos_encoding_ops", name="vllm.pos_encoding_ops",
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"], sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(positional_encoding_extension) ext_modules.append(positional_encoding_extension)
@@ -101,7 +128,10 @@ ext_modules.append(positional_encoding_extension)
layernorm_extension = CUDAExtension( layernorm_extension = CUDAExtension(
name="vllm.layernorm_ops", name="vllm.layernorm_ops",
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"], sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(layernorm_extension) ext_modules.append(layernorm_extension)
@@ -109,7 +139,10 @@ ext_modules.append(layernorm_extension)
activation_extension = CUDAExtension( activation_extension = CUDAExtension(
name="vllm.activation_ops", name="vllm.activation_ops",
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"], sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
) )
ext_modules.append(activation_extension) ext_modules.append(activation_extension)
@@ -124,8 +157,8 @@ def find_version(filepath: str):
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
""" """
with open(filepath) as fp: with open(filepath) as fp:
version_match = re.search( version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M) fp.read(), re.M)
if version_match: if version_match:
return version_match.group(1) return version_match.group(1)
raise RuntimeError("Unable to find version string.") raise RuntimeError("Unable to find version string.")
@@ -148,7 +181,8 @@ setuptools.setup(
version=find_version(get_path("vllm", "__init__.py")), version=find_version(get_path("vllm", "__init__.py")),
author="vLLM Team", author="vLLM Team",
license="Apache 2.0", license="Apache 2.0",
description="A high-throughput and memory-efficient inference and serving engine for LLMs", description=("A high-throughput and memory-efficient inference and "
"serving engine for LLMs"),
long_description=read_readme(), long_description=read_readme(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/vllm-project/vllm", url="https://github.com/vllm-project/vllm",
@@ -160,11 +194,12 @@ setuptools.setup(
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
], ],
packages=setuptools.find_packages( packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")), "examples", "tests")),
python_requires=">=3.8", python_requires=">=3.8",
install_requires=get_requirements(), install_requires=get_requirements(),
ext_modules=ext_modules, ext_modules=ext_modules,

View File

@@ -0,0 +1,50 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict
import uvicorn
from fastapi.responses import JSONResponse, Response
import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
app = vllm.entrypoints.api_server.app
class AsyncLLMEngineWithStats(AsyncLLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0
async def abort(self, request_id: str) -> None:
await super().abort(request_id)
self._num_aborts += 1
def testing_stats(self) -> Dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}
@app.get("/stats")
def stats() -> Response:
"""Get the statistics of the engine."""
return JSONResponse(engine.testing_stats())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
vllm.entrypoints.api_server.engine = engine
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
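As a rough usage sketch (the invocation below is illustrative and not taken from this diff), the instrumented server can be launched by hand and its abort counter read back through the extra /stats route:

# python api_server_async_engine.py --model facebook/opt-125m --port 8000
import requests
print(requests.get("http://localhost:8000/stats").json())  # e.g. {"num_aborted_requests": 0}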


@@ -0,0 +1,86 @@
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path
import pytest
import requests
def _query_server(prompt: str) -> dict:
response = requests.post("http://localhost:8000/generate",
json={
"prompt": prompt,
"max_tokens": 100,
"temperature": 0,
"ignore_eos": True
})
response.raise_for_status()
return response.json()
@pytest.fixture
def api_server():
script_path = Path(__file__).parent.joinpath(
"api_server_async_engine.py").absolute()
uvicorn_process = subprocess.Popen([
sys.executable, "-u",
str(script_path), "--model", "facebook/opt-125m"
])
yield
uvicorn_process.terminate()
def test_api_server(api_server):
"""
Run the API server and test it.
We run both the server and requests in separate processes.
We test that the server can handle incoming requests, including
multiple requests at the same time, and that it can handle requests
being cancelled without crashing.
"""
with Pool(32) as pool:
# Wait until the server is ready
prompts = ["Hello world"] * 1
result = None
while not result:
try:
for result in pool.map(_query_server, prompts):
break
except:
time.sleep(1)
# Actual tests start here
# Try with 1 prompt
for result in pool.map(_query_server, prompts):
assert result
num_aborted_requests = requests.get(
"http://localhost:8000/stats").json()["num_aborted_requests"]
assert num_aborted_requests == 0
# Try with 100 prompts
prompts = ["Hello world"] * 100
for result in pool.map(_query_server, prompts):
assert result
# Cancel requests
pool.map_async(_query_server, prompts)
time.sleep(0.01)
pool.terminate()
pool.join()
# check cancellation stats
num_aborted_requests = requests.get(
"http://localhost:8000/stats").json()["num_aborted_requests"]
assert num_aborted_requests > 0
# check that server still runs after cancellations
with Pool(32) as pool:
# Try with 100 prompts
prompts = ["Hello world"] * 100
for result in pool.map(_query_server, prompts):
assert result


@@ -0,0 +1,54 @@
import pytest
from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput
def test_request_tracker():
tracker = RequestTracker()
stream_1 = tracker.add_request("1")
new, finished = tracker.get_new_and_finished_requests()
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
assert not stream_1.finished
stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
new, finished = tracker.get_new_and_finished_requests()
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
assert not finished
assert not stream_2.finished
assert not stream_3.finished
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request("1")
tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "1" in finished
assert not new
assert stream_1.finished
stream_4 = tracker.add_request("4")
tracker.abort_request("4")
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
assert not new
assert stream_4.finished
stream_5 = tracker.add_request("5")
tracker.process_request_output(
RequestOutput("2", "output", [], [], finished=True))
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "2" in finished
assert len(new) == 1
assert new[0]["request_id"] == "5"
assert stream_2.finished
assert not stream_5.finished
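Reusing the imports at the top of this test file, a minimal sketch of the produce/consume cycle the tracker is being exercised with (only methods that appear in the test above are used):

tracker = RequestTracker()
stream = tracker.add_request("req-0")                      # producer side
new, _ = tracker.get_new_and_finished_requests()           # engine loop picks it up
tracker.process_request_output(
    RequestOutput("req-0", "output", [], [], finished=True))
_, finished = tracker.get_new_and_finished_requests()
assert "req-0" in finished and stream.finished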

tests/conftest.py (new file, 178 lines)

@@ -0,0 +1,178 @@
from typing import List, Optional, Tuple
import pytest
import torch
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
_TEST_PROMPTS = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
"Describe the basic components of a neural network and how it can be trained.",
"Write a short story about a robot that dreams for the first time.",
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]
@pytest.fixture
def example_prompts() -> List[str]:
return _TEST_PROMPTS
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
class HfRunner:
def __init__(
self,
model_name: str,
tokenizer_name: Optional[str] = None,
dtype: str = "half",
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
def generate(
self,
prompts: List[str],
**kwargs,
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
output_ids = self.model.generate(
input_ids.cuda(),
use_cache=True,
**kwargs,
)
output_str = self.tokenizer.batch_decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
output_ids = output_ids.cpu().tolist()
outputs.append((output_ids, output_str))
return outputs
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
outputs[i] = (output_ids[0], output_str[0])
return outputs
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens,
num_beams=beam_width,
num_return_sequences=beam_width)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
for j in range(len(output_ids)):
output_ids[j] = [
x for x in output_ids[j]
if x != self.tokenizer.pad_token_id
]
outputs[i] = (output_ids, output_str)
return outputs
@pytest.fixture
def hf_runner():
return HfRunner
class VllmRunner:
def __init__(
self,
model_name: str,
tokenizer_name: Optional[str] = None,
dtype: str = "half",
) -> None:
self.model = LLM(
model=model_name,
tokenizer=tokenizer_name,
trust_remote_code=True,
dtype=dtype,
swap_space=0,
)
def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str]]:
req_outputs = self.model.generate(prompts,
sampling_params=sampling_params)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids = []
req_sample_output_strs = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens)
outputs = self.generate(prompts, beam_search_params)
return outputs
@pytest.fixture
def vllm_runner():
return VllmRunner
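A hypothetical test built on these fixtures (mirroring how the model tests later in this diff consume them) would look roughly like:

def test_greedy_outputs_match(hf_runner, vllm_runner, example_prompts):
    hf = hf_runner("facebook/opt-125m", dtype="half")
    hf_outputs = hf.generate_greedy(example_prompts, max_tokens=32)
    del hf
    vllm = vllm_runner("facebook/opt-125m", dtype="half")
    vllm_outputs = vllm.generate_greedy(example_prompts, max_tokens=32)
    del vllm
    for (_, hf_str), (_, vllm_str) in zip(hf_outputs, vllm_outputs):
        assert hf_str == vllm_str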

tests/kernels/conftest.py (new file, 43 lines)

@@ -0,0 +1,43 @@
from typing import List, Tuple
import pytest
import torch
def create_kv_caches(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
seed: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device='cuda')
key_cache.uniform_(-scale, scale)
key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device='cuda')
value_cache.uniform_(-scale, scale)
value_caches.append(value_cache)
return key_caches, value_caches
@pytest.fixture()
def kv_cache_factory():
return create_kv_caches
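A minimal sketch of how a kernel test is expected to request caches from this fixture (argument values are arbitrary examples; a CUDA device and the torch import above are assumed):

def test_uses_kv_cache_factory(kv_cache_factory):
    key_caches, value_caches = kv_cache_factory(
        num_blocks=128, block_size=16, num_layers=1,
        num_heads=8, head_size=128, dtype=torch.half, seed=0)
    assert key_caches[0].shape[0] == 128     # (num_blocks, heads, head_size // x, block_size, x)
    assert value_caches[0].shape[-1] == 16   # (num_blocks, heads, head_size, block_size)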


@@ -1,20 +1,34 @@
+import pytest
 import torch
 import torch.nn.functional as F
+from transformers.activations import get_activation

 from vllm import activation_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
+SEEDS = [0]
+

 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_silu_and_mul(
+def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.silu_and_mul(out, x)
@@ -22,9 +36,40 @@ def run_silu_and_mul(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-def test_silu_and_mul() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_silu_and_mul(num_tokens, d, dtype)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_gelu_new(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_new(out, x)
+    ref_out = get_activation("gelu_new")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+def test_gelu_fast(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_fast(out, x)
+    ref_out = get_activation("gelu_fast")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


@@ -1,14 +1,24 @@
import random import random
from typing import List, Optional from typing import List, Optional, Tuple
import pytest
import torch import torch
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm import attention_ops from vllm import attention_ops
MAX_SEQ_LEN = 4096 MAX_SEQ_LEN = 8192
TEST_SEED = 0 NUM_BLOCKS = 128 # Arbitrary values for testing
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
def ref_masked_attention( def ref_masked_attention(
@@ -18,29 +28,34 @@ def ref_masked_attention(
scale: float, scale: float,
attn_mask: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
query = query * scale attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn = torch.einsum('qhd,khd->hqk', query, key)
if attn_mask is not None: if attn_mask is not None:
attn = attn + attn_mask attn_weights = attn_weights + attn_mask.float()
attn = torch.softmax(attn, dim=-1) attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum('hqk,khd->qhd', attn, value) out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out return out
def ref_single_query_cached_kv_attention( def ref_single_query_cached_kv_attention(
output: torch.Tensor, output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
num_queries_per_kv: int,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
block_tables: torch.Tensor, block_tables: torch.Tensor,
context_lens: torch.Tensor, context_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> None: ) -> None:
num_heads = value_cache.shape[1] num_query_heads = query.shape[1]
num_kv_heads = value_cache.shape[1]
head_size = value_cache.shape[2] head_size = value_cache.shape[2]
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs = query.shape[0]
num_input_tokens = query.shape[0] block_tables = block_tables.cpu().tolist()
for i in range(num_input_tokens): context_lens = context_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0) q = query[i].unsqueeze(0)
block_table = block_tables[i] block_table = block_tables[i]
context_len = int(context_lens[i]) context_len = int(context_lens[i])
@@ -52,30 +67,138 @@ def ref_single_query_cached_kv_attention(
block_offset = j % block_size block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :] k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size) k = k.reshape(num_kv_heads, head_size)
keys.append(k) keys.append(k)
v = value_cache[block_number, :, :, block_offset] v = value_cache[block_number, :, :, block_offset]
values.append(v) values.append(v)
keys = torch.stack(keys, dim=0) keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0) values = torch.stack(values, dim=0)
if num_queries_per_kv > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
scale = 1.0 / (head_size ** 0.5) alibi_bias = None
out = ref_masked_attention(q, keys, values, scale) if alibi_slopes is not None:
out = out.view(num_heads, head_size) # Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device="cuda").int()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
out = out.view(num_query_heads, head_size)
output[i].copy_(out, non_blocking=True) output[i].copy_(out, non_blocking=True)
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_single_query_cached_kv_attention(
kv_cache_factory,
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
seed: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device="cuda")
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size, dtype,
seed)
key_cache, value_cache = key_caches[0], value_caches[0]
# Call the paged attention kernel.
output = torch.empty_like(query)
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
# Run the reference implementation.
ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
query,
num_queries_per_kv,
key_cache,
value_cache,
block_tables,
context_lens,
scale,
alibi_slopes,
)
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
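# Aside (illustrative, not part of this commit): the MQA/GQA handling above
# amounts to repeating each KV head so it lines up with its query heads.
# For the (64, 8) head configuration, torch.repeat_interleave(keys, 8, dim=1)
# expands a [context_len, 8, head_size] tensor to [context_len, 64, head_size],
# after which the reference attention treats it like ordinary multi-head attention.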
def ref_multi_query_kv_attention( def ref_multi_query_kv_attention(
cu_seq_lens: List[int], cu_seq_lens: List[int],
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
scale: float,
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
head_size = query.shape[-1]
scale = 1.0 / (head_size ** 0.5)
num_seqs = len(cu_seq_lens) - 1 num_seqs = len(cu_seq_lens) - 1
ref_outputs = [] ref_outputs = []
for i in range(num_seqs): for i in range(num_seqs):
@@ -84,10 +207,10 @@ def ref_multi_query_kv_attention(
seq_len = end_idx - start_idx seq_len = end_idx - start_idx
# Create attention mask. # Create attention mask.
attn_mask = torch.triu( attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device='cuda') attn_mask = attn_mask.to(dtype=dtype, device="cuda")
ref_output = ref_masked_attention( ref_output = ref_masked_attention(
query[start_idx:end_idx], query[start_idx:end_idx],
@@ -101,147 +224,43 @@ def ref_multi_query_kv_attention(
return ref_output return ref_output
def ref_multi_query_cached_kv_attention( # TODO(woosuk): Add tests for USE_ALIBI=True.
cu_query_lens: List[int], @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
query: torch.Tensor, @pytest.mark.parametrize("num_heads", NUM_HEADS)
key_cache: torch.Tensor, @pytest.mark.parametrize("head_size", HEAD_SIZES)
value_cache: torch.Tensor, @pytest.mark.parametrize("dtype", DTYPES)
block_tables: torch.Tensor, @pytest.mark.parametrize("seed", SEEDS)
context_lens: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
block_size = value_cache.shape[3]
scale = 1.0 / (head_size ** 0.5)
num_queries = len(cu_query_lens) - 1
ref_outputs = []
for i in range(num_queries):
start_idx = cu_query_lens[i]
end_idx = cu_query_lens[i + 1]
query_len = end_idx - start_idx
context_len = int(context_lens[i])
block_table = block_tables[i]
# Create attention mask
attn_mask = torch.triu(
torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
keys = []
values = []
for j in range(context_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size)
keys.append(k)
v = value_cache[block_number, :, :, block_offset]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
ref_output = ref_masked_attention(
query[start_idx:end_idx],
keys,
values,
scale,
attn_mask=attn_mask,
)
ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
return ref_output
@torch.inference_mode() @torch.inference_mode()
def run_single_query_cached_kv_attention( def test_multi_query_kv_attention(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
qkv = torch.empty(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, _, _ = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.empty(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
key_cache.uniform_(-1e-3, 1e-3)
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.empty(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
value_cache.uniform_(-1e-3, 1e-3)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_tokens):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty(
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
)
ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
query,
key_cache,
value_cache,
block_tables,
context_lens,
)
# NOTE(woosuk): Due to the difference in the data types the two
# implementations use for attention softmax logits and accumulation,
# there is a small difference in the final outputs.
# We should use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def run_multi_query_kv_attention(
num_seqs: int, num_seqs: int,
num_heads: int, num_heads: Tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int,
) -> None: ) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
num_tokens = sum(seq_lens) num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size ** 0.5)) scale = float(1.0 / (head_size**0.5))
qkv = torch.empty( num_query_heads, num_kv_heads = num_heads
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') qkv = torch.empty(num_tokens,
qkv.uniform_(-1e-3, 1e-3) num_query_heads + 2 * num_kv_heads,
query, key, value = qkv.unbind(dim=1) head_size,
dtype=dtype,
device="cuda")
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
attn_op = xops.fmha.cutlass.FwOp() num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward( output = xops.memory_efficient_attention_forward(
query.unsqueeze(0), query.unsqueeze(0),
@@ -250,7 +269,6 @@ def run_multi_query_kv_attention(
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, p=0.0,
scale=scale, scale=scale,
op=attn_op,
) )
output = output.squeeze(0) output = output.squeeze(0)
@@ -262,40 +280,7 @@ def run_multi_query_kv_attention(
query, query,
key, key,
value, value,
scale,
dtype, dtype,
) )
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
def test_single_query_cached_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for block_size in [8, 16, 32]:
for head_size in [64, 80, 96, 128]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
run_single_query_cached_kv_attention(
num_tokens=37,
num_heads=3,
head_size=head_size,
block_size=block_size,
num_blocks=1024,
dtype=dtype,
)
def test_multi_query_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [64, 80, 96, 128]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}')
run_multi_query_kv_attention(
num_seqs=5,
num_heads=3,
head_size=head_size,
dtype=dtype,
)


@@ -1,12 +1,32 @@
import random import random
import pytest
import torch import torch
from vllm import cache_ops from vllm import cache_ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
NUM_LAYERS = [5] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024] # Arbitrary values for testing
NUM_MAPPINGS = [32, 256] # Arbitrary values for testing
SEEDS = [0]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode() @torch.inference_mode()
def run_copy_blocks( def test_copy_blocks(
kv_cache_factory,
num_mappings: int, num_mappings: int,
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
@@ -14,151 +34,113 @@ def run_copy_blocks(
block_size: int, block_size: int,
num_blocks: int, num_blocks: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int,
) -> None: ) -> None:
# Generate random block mappings. random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, num_mappings) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)} block_mapping = {}
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping[src] = [dst1, dst2]
# Create the KV cache. # Create the KV caches.
x = 16 // torch.tensor([], dtype=dtype).element_size() key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) num_layers, num_heads,
key_caches = [] head_size, dtype, seed)
for _ in range(num_layers):
key_cache = torch.randn(
size=key_cache_shape, dtype=dtype, device='cuda')
key_caches.append(key_cache)
cloned_key_caches = []
for key_cache in key_caches:
cloned_key_caches.append(key_cache.clone())
value_cache_shape = (num_blocks, num_heads, head_size, block_size) # Clone the KV caches.
value_caches = [] cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
for _ in range(num_layers): cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')
value_caches.append(value_cache)
cloned_value_caches = []
for value_cache in value_caches:
cloned_value_caches.append(value_cache.clone())
# Call the copy blocks kernel. # Call the copy blocks kernel.
cache_ops.copy_blocks(key_caches, value_caches, block_mapping) cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
# Reference implementation. # Run the reference implementation.
for src, dsts in block_mapping.items(): for src, dsts in block_mapping.items():
for dst in dsts: for dst in dsts:
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst] = cloned_key_cache[src] cloned_key_cache[dst] = cloned_key_cache[src]
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst] = cloned_value_cache[src] cloned_value_cache[dst] = cloned_value_cache[src]
# Compare the results. # Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
assert torch.allclose(value_cache, cloned_value_cache) assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode() @torch.inference_mode()
def run_reshape_and_cache( def test_reshape_and_cache(
kv_cache_factory,
num_tokens: int, num_tokens: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
block_size: int, block_size: int,
num_blocks: int, num_blocks: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int,
) -> None: ) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Create a random slot mapping.
num_slots = block_size * num_blocks num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens) slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
qkv = torch.randn( qkv = torch.randn(num_tokens,
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') 3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
_, key, value = qkv.unbind(dim=1) _, key, value = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size() # Create the KV caches.
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') num_heads, head_size, dtype,
cloned_key_cache = key_cache.clone() seed)
key_cache, value_cache = key_caches[0], value_caches[0]
value_cache_shape = (num_blocks, num_heads, head_size, block_size) # Clone the KV caches.
value_cache = torch.randn( cloned_key_cache = key_cache.clone()
size=value_cache_shape, dtype=dtype, device='cuda')
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) # Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping)
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_indicies = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
for i in range(num_tokens): for i in range(num_tokens):
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x) block_idx = block_indicies[i]
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor') block_offset = block_offsets[i]
block_offset = slot_mapping[i] % block_size
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i]
assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache) assert torch.allclose(value_cache, cloned_value_cache)
@torch.inference_mode()
def run_gather_cached_kv(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)
qkv_clone = qkv.clone()
_, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
# Reference implementation.
for i in range(num_tokens):
reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
block_offset = slot_mapping[i] % block_size
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
cloned_value[i] = value_cache[block_idx, :, :, block_offset]
assert torch.allclose(key, cloned_key)
assert torch.allclose(value, cloned_value)
def test_copy_blocks() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_copy_blocks(
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
block_size=8, num_blocks=1024, dtype=dtype)
def test_reshape_and_cache() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
def test_gather_cached_kv() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_gather_cached_kv(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)


@@ -1,33 +1,50 @@
+import pytest
 import torch
 import torch.nn as nn

 from vllm import layernorm_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+SEEDS = [0]
+

 class RefRMSNorm(nn.Module):

     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         weight = torch.empty(hidden_size)
-        weight.uniform_(-1e-3, 1e-3)
+        weight.normal_(mean=1.0, std=0.1)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps

     def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance +
+                                                    self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rms_norm(
+def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    scale = float(hidden_size**-0.5)
+    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x.uniform_(-scale, scale)
     ref = RefRMSNorm(hidden_size).to(dtype).cuda()
     out = torch.empty_like(x)
@@ -38,17 +55,4 @@ def run_rms_norm(
         ref.variance_epsilon,
     )
     ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
-
-
-def test_rms_norm() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 128, 2048]:
-            for hidden_size in [13, 64, 1024, 5120]:
-                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
-                      f'{num_tokens}, hidden_size={hidden_size}')
-                run_rms_norm(
-                    num_tokens=num_tokens,
-                    hidden_size=hidden_size,
-                    dtype=dtype,
-                )
+    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)


@@ -1,47 +1,70 @@
from typing import Tuple from typing import Optional, Tuple
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm import pos_encoding_ops from vllm import pos_encoding_ops
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
SEEDS = [0]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2] def rotate_neox(x: torch.Tensor) -> torch.Tensor:
x2 = x[..., x.shape[-1] // 2 :] x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb( def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def apply_rope(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
is_neox_style: bool,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
q_embed = (q * cos) + (rotate_half(q) * sin) rotate_fn = rotate_neox if is_neox_style else rotate_gptj
k_embed = (k * cos) + (rotate_half(k) * sin) q_embed = (q * cos) + (rotate_fn(q) * sin)
k_embed = (k * cos) + (rotate_fn(k) * sin)
return q_embed, k_embed return q_embed, k_embed
class RefRotaryEmbeddingNeox(nn.Module): class RefRotaryEmbedding(nn.Module):
"""Reference implementation of the GPT-NeoX style rotary embedding.""" """Reference implementation of rotary embedding."""
def __init__( def __init__(
self, self,
dim: int, dim: int,
max_position_embeddings: int = 2048, is_neox_style: bool,
max_position_embeddings: int = 8192,
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
super().__init__() super().__init__()
self.rotary_dim = dim self.rotary_dim = dim
self.is_neox_style = is_neox_style
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings. # Create cos and sin embeddings.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float() t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float()) freqs = torch.einsum("i,j->ij", t, inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1) if is_neox_style:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.repeat_interleave(freqs, 2, -1)
cos = emb.cos().to(dtype=inv_freq.dtype) cos = emb.cos().to(dtype=inv_freq.dtype)
sin = emb.sin().to(dtype=inv_freq.dtype) sin = emb.sin().to(dtype=inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False) self.register_buffer("cos_cached", cos, persistent=False)
@@ -49,22 +72,22 @@ class RefRotaryEmbeddingNeox(nn.Module):
def forward( def forward(
self, self,
positions: torch.Tensor, # [num_tokens] positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size] query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size] key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., :self.rotary_dim]
query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim:]
query_pass = query[..., self.rotary_dim :] key_rot = key[..., :self.rotary_dim]
key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim :]
query_rot = query_rot.transpose(0, 1) query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1) key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached) cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached) sin = F.embedding(positions, self.sin_cached)
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
self.is_neox_style)
query_rot = query_rot.transpose(0, 1).contiguous() query_rot = query_rot.transpose(0, 1).contiguous()
key_rot = key_rot.transpose(0, 1).contiguous() key_rot = key_rot.transpose(0, 1).contiguous()
@@ -75,24 +98,45 @@ class RefRotaryEmbeddingNeox(nn.Module):
return query, key return query, key
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode() @torch.inference_mode()
def run_rotary_embedding_neox( def test_rotary_embedding(
is_neox_style: bool,
num_tokens: int, num_tokens: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
max_position: int, rotary_dim: Optional[int],
rotary_dim: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int,
max_position: int = 8192,
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
positions = torch.randint(0, max_position, (num_tokens,), device='cuda') if rotary_dim is None:
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda') rotary_dim = head_size
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda') torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
query = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="cuda")
key = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="cuda")
# Create the rotary embedding. # Create the rotary embedding.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) inv_freq = 1.0 / (base**(
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position).float() t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1) cos_sin_cache = torch.cat((cos, sin), dim=-1)
@@ -101,20 +145,22 @@ def run_rotary_embedding_neox(
# Run the kernel. The kernel is in-place, so we need to clone the inputs. # Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone() out_query = query.clone()
out_key = key.clone() out_key = key.clone()
pos_encoding_ops.rotary_embedding_neox( pos_encoding_ops.rotary_embedding(
positions, positions,
out_query, out_query,
out_key, out_key,
head_size, head_size,
cos_sin_cache, cos_sin_cache,
is_neox_style,
) )
# Run the reference implementation. # Run the reference implementation.
ref_rotary_embedding = RefRotaryEmbeddingNeox( ref_rotary_embedding = RefRotaryEmbedding(
dim=rotary_dim, dim=rotary_dim,
is_neox_style=is_neox_style,
max_position_embeddings=max_position, max_position_embeddings=max_position,
base=base, base=base,
).to(dtype=dtype, device='cuda') ).to(dtype=dtype, device="cuda")
ref_query, ref_key = ref_rotary_embedding( ref_query, ref_key = ref_rotary_embedding(
positions, positions,
query.view(num_tokens, num_heads, head_size), query.view(num_tokens, num_heads, head_size),
@@ -124,19 +170,5 @@ def run_rotary_embedding_neox(
ref_key = ref_key.view(num_tokens, num_heads * head_size) ref_key = ref_key.view(num_tokens, num_heads * head_size)
# Compare the results. # Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
def test_rotary_embedding_neox() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Running tests for head_size={head_size} and dtype={dtype}')
run_rotary_embedding_neox(
num_tokens=2145,
num_heads=5,
head_size=head_size,
max_position=8192,
rotary_dim=head_size,
dtype=dtype,
)
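For intuition, a small sketch (not part of this diff; it reuses rotate_neox/rotate_gptj defined above and the torch import) of what the two rotation layouts do to a toy 4-dimensional head:

x = torch.tensor([1., 2., 3., 4.])
rotate_neox(x)  # tensor([-3., -4.,  1.,  2.])  -- halves swapped and negated
rotate_gptj(x)  # tensor([-2.,  1., -4.,  3.])  -- adjacent even/odd pairs rotated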


@@ -0,0 +1,45 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/models/test_models.py --forked`.
"""
import pytest
MODELS = [
"facebook/opt-125m",
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"mosaicml/mpt-7b",
"tiiuae/falcon-7b",
"meta-llama/Llama-2-7b-hf",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@@ -0,0 +1,46 @@
"""Compare the outputs of HF and vLLM when using beam search.
Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import pytest
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [128]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")


@@ -1,3 +1,5 @@
+"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
+
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
@@ -6,7 +8,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams

-__version__ = "0.1.0"
+__version__ = "0.1.7"

 __all__ = [
     "LLM",


@@ -35,7 +35,8 @@ class LogicalTokenBlock:
     def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
-        self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
+        curr_idx = self.num_tokens
+        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)

     def get_token_ids(self) -> List[int]:


@@ -1,14 +1,15 @@
from typing import Optional from typing import Optional
import torch import torch
from transformers import AutoConfig, PretrainedConfig from transformers import PretrainedConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory from vllm.utils import get_cpu_memory
logger = init_logger(__name__) logger = init_logger(__name__)
_GiB = 1 << 30 _GB = 1 << 30
class ModelConfig: class ModelConfig:
@@ -16,11 +17,23 @@ class ModelConfig:
Args: Args:
model: Name or path of the huggingface model to use. model: Name or path of the huggingface model to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface. default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading. load_format: The format of the model weights to load:
This can increase the disk usage by up to 2x. "auto" will try to load the weights in the safetensors format and
use_dummy_weights: Use dummy values for model weights (for profiling). fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models. for BF16 models.
@@ -30,20 +43,44 @@ class ModelConfig:
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
download_dir: Optional[str], download_dir: Optional[str],
use_np_weights: bool, load_format: str,
use_dummy_weights: bool,
dtype: str, dtype: str,
seed: int, seed: int,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.download_dir = download_dir self.download_dir = download_dir
self.use_np_weights = use_np_weights self.load_format = load_format
self.use_dummy_weights = use_dummy_weights
self.seed = seed self.seed = seed
self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model) self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_load_format()
self._verify_tokenizer_mode()
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
if load_format not in [
"auto", "pt", "safetensors", "npcache", "dummy"
]:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
self.load_format = load_format
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
@@ -73,9 +110,49 @@ class ModelConfig:
return self.hf_config.hidden_size // self.hf_config.num_attention_heads return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int: def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config,
"multi_query", False):
# Multi-query attention, only one KV head.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size)
# For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads //
parallel_config.tensor_parallel_size)
total_num_attention_heads = self.hf_config.num_attention_heads total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size return total_num_attention_heads // parallel_config.tensor_parallel_size
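For a concrete sense of the head arithmetic above, here is a small illustrative calculation using a stand-in config object (the attribute values are made up):

from types import SimpleNamespace

# Hypothetical LLaMA-2-style config with grouped-query attention.
hf_config = SimpleNamespace(num_attention_heads=64, num_key_value_heads=8)
tensor_parallel_size = 4

# Each tensor-parallel rank owns a slice of the KV heads and query heads.
kv_heads_per_rank = hf_config.num_key_value_heads // tensor_parallel_size
query_heads_per_rank = hf_config.num_attention_heads // tensor_parallel_size
print(kv_heads_per_rank, query_heads_per_rank)  # 2 16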
def get_max_model_len(self) -> int:
max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(self.hf_config, key, None)
if max_len_key is not None:
max_model_len = min(max_model_len, max_len_key)
return max_model_len
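The method above simply takes the most restrictive of the recognized length keys; a self-contained sketch with a made-up config:

from types import SimpleNamespace

# Illustrative config that exposes two of the recognized keys.
hf_config = SimpleNamespace(max_position_embeddings=4096, max_seq_len=2048)

possible_keys = ["max_position_embeddings", "n_positions", "max_seq_len",
                 "max_sequence_length", "max_seq_length", "seq_len"]
max_model_len = float("inf")
for key in possible_keys:
    value = getattr(hf_config, key, None)
    if value is not None:
        max_model_len = min(max_model_len, value)
print(max_model_len)  # 2048 -- the smallest limit wins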
def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size return total_num_hidden_layers // parallel_config.pipeline_parallel_size
@@ -90,6 +167,7 @@ class CacheConfig:
vLLM execution. vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB). swap_space: Size of the CPU swap space per GPU (in GiB).
""" """
def __init__( def __init__(
self, self,
block_size: int, block_size: int,
@@ -98,7 +176,7 @@ class CacheConfig:
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GiB self.swap_space_bytes = swap_space * _GB
self._verify_args() self._verify_args()
# Will be set after profiling. # Will be set after profiling.
@@ -121,14 +199,13 @@ class CacheConfig:
num_gpus_per_node = parallel_config.tensor_parallel_size num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
msg = ( msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
f"{cpu_memory_usage / _GiB:.2f} GiB out of " f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is " "allocated for the swap space.")
"allocated for the swap space.")
if cpu_memory_usage > 0.7 * total_cpu_memory: if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg) raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory: elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. " + msg) logger.warning("Possibly too large swap space. " + msg)
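To make the warning thresholds above concrete, here is an illustrative back-of-the-envelope check (the host RAM figure is made up, and _GB is assumed to be 1 << 30 as in the config module):

_GB = 1 << 30  # assumed value of the byte constant used above

swap_space = 16               # --swap-space, GiB per GPU
num_gpus_per_node = 4         # tensor-parallel size on one node
total_cpu_memory = 128 * _GB  # assumed host RAM

cpu_memory_usage = swap_space * _GB * num_gpus_per_node
fraction = cpu_memory_usage / total_cpu_memory
print(f"{fraction:.2f}")  # 0.50 -> above the 0.4 warning threshold, below the 0.7 error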
class ParallelConfig: class ParallelConfig:
@@ -141,6 +218,7 @@ class ParallelConfig:
True if either pipeline_parallel_size or tensor_parallel_size is True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1. greater than 1.
""" """
def __init__( def __init__(
self, self,
pipeline_parallel_size: int, pipeline_parallel_size: int,
@@ -170,14 +248,15 @@ class SchedulerConfig:
a single iteration. a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single max_num_seqs: Maximum number of sequences to be processed in a single
iteration. iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
""" """
def __init__(
self, def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
max_num_batched_tokens: int, max_model_len: int) -> None:
max_num_seqs: int,
) -> None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -221,7 +300,7 @@ def _get_and_verify_dtype(
pass pass
else: else:
# Casting between float16 and bfloat16 is allowed with a warning. # Casting between float16 and bfloat16 is allowed with a warning.
logger.warn(f"Casting {config_dtype} to {torch_dtype}.") logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: if torch_dtype == torch.bfloat16:
View File
@@ -27,8 +27,9 @@ class BlockAllocator:
# Initialize the free blocks. # Initialize the free blocks.
self.free_blocks: List[PhysicalTokenBlock] = [] self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks): for i in range(num_blocks):
block = PhysicalTokenBlock( block = PhysicalTokenBlock(device=device,
device=device, block_number=i, block_size=block_size) block_number=i,
block_size=block_size)
self.free_blocks.append(block) self.free_blocks.append(block)
def allocate(self) -> PhysicalTokenBlock: def allocate(self) -> PhysicalTokenBlock:
@@ -84,10 +85,12 @@ class BlockSpaceManager:
num_required_blocks = len(seq.logical_token_blocks) num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction. # Use watermark to avoid frequent cache eviction.
return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks return (num_free_gpu_blocks - num_required_blocks >=
self.watermark_blocks)
def allocate(self, seq_group: SequenceGroup) -> None: def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same prompt. # NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = seq_group.get_seqs()[0] seq = seq_group.get_seqs()[0]
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
@@ -143,7 +146,8 @@ class BlockSpaceManager:
for block in src_block_table: for block in src_block_table:
block.ref_count += 1 block.ref_count += 1
def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: def _get_physical_blocks(
self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by # NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group. # the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set() blocks: Set[PhysicalTokenBlock] = set()
@@ -168,9 +172,7 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block. # CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
if seq.is_finished():
continue
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
@@ -199,9 +201,7 @@ class BlockSpaceManager:
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block. # GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.is_finished():
continue
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
View File
@@ -1,19 +1,16 @@
import enum import enum
import time import time
from typing import Dict, List, Optional, Tuple from typing import Dict, Iterable, List, Optional, Tuple, Union
from vllm.config import CacheConfig, SchedulerConfig from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import BlockSpaceManager from vllm.core.block_manager import BlockSpaceManager
from vllm.core.policy import PolicyFactory from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs, SequenceGroupMetadata, SequenceStatus)
SequenceStatus)
logger = init_logger(__name__) logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class PreemptionMode(enum.Enum): class PreemptionMode(enum.Enum):
"""Preemption modes. """Preemption modes.
@@ -32,20 +29,28 @@ class SchedulerOutputs:
def __init__( def __init__(
self, self,
scheduled_seq_groups: List[SequenceGroup],
prompt_run: bool,
num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]], blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
) -> None: ) -> None:
self.scheduled_seq_groups = scheduled_seq_groups
self.prompt_run = prompt_run
self.num_batched_tokens = num_batched_tokens
self.blocks_to_swap_in = blocks_to_swap_in self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out self.blocks_to_swap_out = blocks_to_swap_out
self.blocks_to_copy = blocks_to_copy self.blocks_to_copy = blocks_to_copy
# Swap in and swap out should never happen at the same time. # Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out) assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
def is_empty(self) -> bool: def is_empty(self) -> bool:
return (not self.blocks_to_swap_in # NOTE: We do not consider the ignored sequence groups.
and not self.blocks_to_swap_out return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_copy) and not self.blocks_to_swap_out and not self.blocks_to_copy)
class Scheduler: class Scheduler:
@@ -54,14 +59,15 @@ class Scheduler:
self, self,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
log_stats: bool,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
self.log_stats = log_stats
self.prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
# Instantiate the scheduling policy. # Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name='fcfs') self.policy = PolicyFactory.get_policy(policy_name="fcfs")
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManager( self.block_manager = BlockSpaceManager(
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
@@ -69,6 +75,7 @@ class Scheduler:
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
) )
# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = [] self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state. # Sequence groups in the RUNNING state.
@@ -76,25 +83,30 @@ class Scheduler:
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = [] self.swapped: List[SequenceGroup] = []
self.last_logging_time: float = 0.0
# List[timestamp, num_tokens]
self.num_input_tokens: List[Tuple[float, int]] = []
def add_seq_group(self, seq_group: SequenceGroup) -> None: def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue. # Add sequence groups to the waiting queue.
self.waiting.append(seq_group) self.waiting.append(seq_group)
def abort_seq_group(self, request_id: str) -> None: def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
if isinstance(request_id, str):
request_id = (request_id, )
request_ids = set(request_id)
for state_queue in [self.waiting, self.running, self.swapped]: for state_queue in [self.waiting, self.running, self.swapped]:
for seq_group in state_queue: # We need to reverse the list as we are removing elements
if seq_group.request_id == request_id: # from it as we iterate over it. If we don't do it,
# indices will get messed up and we will skip over elements.
for seq_group in reversed(state_queue):
if seq_group.request_id in request_ids:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(seq_group) state_queue.remove(seq_group)
for seq in seq_group.seqs: for seq in seq_group.get_seqs():
if seq.is_finished(): if seq.is_finished():
continue continue
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED) seq.status = SequenceStatus.FINISHED_ABORTED
return self.free_seq(seq)
request_ids.remove(seq_group.request_id)
if not request_ids:
return
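The comment above about iterating in reverse is easy to demonstrate with a toy list; removing while walking forward would skip the element right after each removal:

queue = ["req-0", "req-1", "req-2", "req-3"]
to_abort = {"req-1", "req-2"}

# Walking the list in reverse keeps the indices of not-yet-visited items
# stable while elements are removed.
for item in reversed(queue):
    if item in to_abort:
        queue.remove(item)
print(queue)  # ['req-0', 'req-3']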
def has_unfinished_seqs(self) -> bool: def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped return self.waiting or self.running or self.swapped
@@ -102,7 +114,7 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)
def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: def _schedule(self) -> SchedulerOutputs:
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {}
@@ -111,10 +123,72 @@ class Scheduler:
# Fix the current time. # Fix the current time.
now = time.time() now = time.time()
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
# in order to minimize the preemption overheads.
# Preemption happens only when there is no available slot to keep all
# the sequence groups in the RUNNING state.
# Join waiting sequences if possible.
if not self.swapped:
    ignored_seq_groups: List[SequenceGroup] = []
    scheduled: List[SequenceGroup] = []
# The total number of sequences on the fly, including the
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
num_batched_tokens = 0
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
assert seq_group.num_seqs() == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if num_prompt_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}")
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
continue
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)
if scheduled:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
)
return scheduler_outputs
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence # In this case, the policy is responsible for deciding which sequence
# groups to preempt. # groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running) self.running = self.policy.sort_by_priority(now, self.running)
@@ -144,129 +218,56 @@ class Scheduler:
# Swap in the sequence groups in the SWAPPED state if possible. # Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped) self.swapped = self.policy.sort_by_priority(now, self.swapped)
while self.swapped and not blocks_to_swap_out: if not preempted:
seq_group = self.swapped[0] num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
# If the sequence group has been preempted in this step, stop. for seq_group in self.running)
if seq_group in preempted:
break
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
# The total number of sequences in the RUNNING state should not while self.swapped:
# exceed the maximum number of sequences. seq_group = self.swapped[0]
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) # If the sequence group cannot be swapped in, stop.
num_curr_seqs = len(self.running) if not self.block_manager.can_swap_in(seq_group):
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
break
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
self.running.append(seq_group)
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
)
# Join waiting sequences if possible.
prompt_group_ids: List[str] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped:
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if (num_batched_tokens + num_prompt_tokens
> self.scheduler_config.max_num_batched_tokens):
break break
# The total number of sequences in the RUNNING state should not # The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences. # exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING) num_new_seqs = seq_group.get_max_num_running_seqs()
num_curr_seqs = len(self.running) if (num_curr_seqs + num_new_seqs >
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: self.scheduler_config.max_num_seqs):
break break
seq_group = self.waiting.pop(0) seq_group = self.swapped.pop(0)
self._allocate(seq_group) self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
self.running.append(seq_group) self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.request_id) # Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state.
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running,
prompt_run=False,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
ignored_seq_groups=[],
) )
if not self.log_stats: return scheduler_outputs
return scheduler_outputs, prompt_group_ids
# TODO(woosuk): Move the below code to the engine.
now = time.time()
if num_batched_tokens > 0:
self.num_input_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time > _LOGGING_INTERVAL_SEC:
self.last_logging_time = now
self.num_input_tokens = [
(t, n) for t, n in self.num_input_tokens
if now - t < _LOGGING_INTERVAL_SEC
]
if len(self.num_input_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
window = now - self.num_input_tokens[0][0]
avg_throughput = total_num_tokens / window
else:
avg_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info(
f"Throughput: {avg_throughput:.1f} tokens/s, "
f"Running: {len(self.running)} reqs, "
f"Swapped: {len(self.swapped)} reqs, "
f"Pending: {len(self.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
return scheduler_outputs, prompt_group_ids
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups. # Schedule sequence groups.
# This function call changes the internal states of the scheduler # This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting. # such as self.running, self.swapped, and self.waiting.
scheduler_outputs, prompt_group_ids = self._schedule() scheduler_outputs = self._schedule()
# Create input data structures. # Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in self.running: for seq_group in scheduler_outputs.scheduled_seq_groups:
is_prompt = seq_group.request_id in prompt_group_ids
seq_data: Dict[int, List[SequenceData]] = {} seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {} block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -276,7 +277,7 @@ class Scheduler:
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=scheduler_outputs.prompt_run,
seq_data=seq_data, seq_data=seq_data,
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
@@ -284,35 +285,10 @@ class Scheduler:
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs return seq_group_metadata_list, scheduler_outputs
def update( def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self, self.block_manager.fork(parent_seq, child_seq)
seq_outputs: Dict[int, SequenceOutputs],
) -> List[SequenceGroup]:
# Update the running sequences and free blocks.
for seq_group in self.running:
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the new tokens. def free_seq(self, seq: Sequence) -> None:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token_id(output.output_token, output.logprobs)
# Return a shallow copy of the running queue to prevent the queue
# from being modified by the caller.
return self.running.copy()
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
seq.status = finish_status
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
@@ -349,8 +325,8 @@ class Scheduler:
# If preemption mode is not specified, we determine the mode as follows: # If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than # We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences # swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not supported. In such a case, # (e.g., beam search), recomputation is not currently supported. In
# we use swapping instead. # such a case, we use swapping instead.
# FIXME(woosuk): This makes our scheduling policy a bit bizarre. # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
# As swapped sequences are prioritized over waiting sequences, # As swapped sequences are prioritized over waiting sequences,
# sequence groups with multiple sequences are implicitly prioritized # sequence groups with multiple sequences are implicitly prioritized
@@ -358,8 +334,7 @@ class Scheduler:
# TODO(woosuk): Support recomputation for sequence groups with multiple # TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel. # sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None: if preemption_mode is None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) if seq_group.get_max_num_running_seqs() == 1:
if len(seqs) == 1:
preemption_mode = PreemptionMode.RECOMPUTE preemption_mode = PreemptionMode.RECOMPUTE
else: else:
preemption_mode = PreemptionMode.SWAP preemption_mode = PreemptionMode.SWAP
@@ -368,7 +343,7 @@ class Scheduler:
elif preemption_mode == PreemptionMode.SWAP: elif preemption_mode == PreemptionMode.SWAP:
self._preempt_by_swap(seq_group, blocks_to_swap_out) self._preempt_by_swap(seq_group, blocks_to_swap_out)
else: else:
assert False, 'Invalid preemption mode.' assert False, "Invalid preemption mode."
def _preempt_by_recompute( def _preempt_by_recompute(
self, self,
@@ -388,9 +363,6 @@ class Scheduler:
seq_group: SequenceGroup, seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
) -> None: ) -> None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
for seq in seqs:
seq.status = SequenceStatus.SWAPPED
self._swap_out(seq_group, blocks_to_swap_out) self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group) self.swapped.append(seq_group)
View File
@@ -11,10 +11,12 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str model: str
tokenizer: Optional[str] = None
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None download_dir: Optional[str] = None
use_np_weights: bool = False load_format: str = 'auto'
use_dummy_weights: bool = False dtype: str = 'auto'
dtype: str = "auto"
seed: int = 0 seed: int = 0
worker_use_ray: bool = False worker_use_ray: bool = False
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
@@ -27,72 +29,117 @@ class EngineArgs:
disable_log_stats: bool = False disable_log_stats: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
parser: argparse.ArgumentParser, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
# Model arguments # Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', parser.add_argument(
help='name or path of the huggingface model to use') '--model',
parser.add_argument('--download-dir', type=str, type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument(
'--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir, default=EngineArgs.download_dir,
help='directory to download and load the weights, ' help='directory to download and load the weights, '
'default to the default cache dir of ' 'default to the default cache dir of '
'huggingface') 'huggingface')
parser.add_argument('--use-np-weights', action='store_true', parser.add_argument(
help='save a numpy copy of model weights for ' '--load-format',
'faster loading. This can increase the disk ' type=str,
'usage by up to 2x.') default=EngineArgs.load_format,
parser.add_argument('--use-dummy-weights', action='store_true', choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
help='use dummy values for model weights') help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available. '
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
# TODO(woosuk): Support FP32. # TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default=EngineArgs.dtype, parser.add_argument(
choices=['auto', 'half', 'bfloat16', 'float'], '--dtype',
help='data type for model weights and activations. ' type=str,
'The "auto" option will use FP16 precision ' default=EngineArgs.dtype,
'for FP32 and FP16 models, and BF16 precision ' choices=['auto', 'half', 'bfloat16', 'float'],
'for BF16 models.') help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
# Parallel arguments # Parallel arguments
parser.add_argument('--worker-use-ray', action='store_true', parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be ' help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU') 'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size, default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages') help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas') help='number of tensor parallel replicas')
# KV cache arguments # KV cache arguments
parser.add_argument('--block-size', type=int, parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
choices=[8, 16, 32], choices=[8, 16, 32],
help='token block size') help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request). # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=EngineArgs.seed, parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='random seed') help='random seed')
parser.add_argument('--swap-space', type=int, parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space, default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU') help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, parser.add_argument('--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization, default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for' help='the percentage of GPU memory to be used for'
'the model executor') 'the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per ' help='maximum number of batched tokens per '
'iteration') 'iteration')
parser.add_argument('--max-num-seqs', type=int, parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration') help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics') help='disable logging statistics')
return parser return parser
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass. # Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)] attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments. # Set the attributes from the parsed arguments.
@@ -103,16 +150,19 @@ class EngineArgs:
self, self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs. # Initialize the configs.
model_config = ModelConfig( model_config = ModelConfig(self.model, self.tokenizer,
self.model, self.download_dir, self.use_np_weights, self.tokenizer_mode, self.trust_remote_code,
self.use_dummy_weights, self.dtype, self.seed) self.download_dir, self.load_format,
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.dtype, self.seed)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space) self.swap_space)
parallel_config = ParallelConfig(self.pipeline_parallel_size, parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size, self.tensor_parallel_size,
self.worker_use_ray) self.worker_use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs) self.max_num_seqs,
model_config.get_max_model_len())
return model_config, cache_config, parallel_config, scheduler_config return model_config, cache_config, parallel_config, scheduler_config
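A minimal sketch of how these pieces fit together from the command line, assuming a vLLM installation matching this revision (the argument values are placeholders, and building the configs downloads the model config from the Hugging Face Hub):

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser(description="vLLM engine demo")
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m",
                          "--load-format", "safetensors",
                          "--dtype", "half",
                          "--max-num-seqs", "64"])

engine_args = EngineArgs.from_cli_args(args)
# create_engine_configs() returns (ModelConfig, CacheConfig,
# ParallelConfig, SchedulerConfig) in that order.
model_config, cache_config, parallel_config, scheduler_config = (
    engine_args.create_engine_configs())
print(scheduler_config.max_model_len)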
@@ -124,12 +174,13 @@ class AsyncEngineArgs(EngineArgs):
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
parser: argparse.ArgumentParser, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray', action='store_true', parser.add_argument('--engine-use-ray',
action='store_true',
help='use Ray to start the LLM engine in a ' help='use Ray to start the LLM engine in a '
'separate process as the server process.') 'separate process as the server process.')
parser.add_argument('--disable-log-requests', action='store_true', parser.add_argument('--disable-log-requests',
action='store_true',
help='disable logging requests') help='disable logging requests')
return parser return parser
View File
@@ -1,7 +1,9 @@
import asyncio import asyncio
import time import time
from typing import Dict, List, Optional from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray from vllm.engine.ray_utils import initialize_cluster, ray
@@ -11,7 +13,202 @@ from vllm.sampling_params import SamplingParams
logger = init_logger(__name__) logger = init_logger(__name__)
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
class AsyncEngineDeadError(RuntimeError):
pass
def _raise_exception_on_finish(task: asyncio.Task,
request_tracker: "RequestTracker") -> None:
msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.")
try:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from exc
raise AsyncEngineDeadError(msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
raise exc
class AsyncStream:
"""A stream of RequestOutputs for a request that can be
iterated over asynchronously."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue = asyncio.Queue()
self._finished = False
def put(self, item: RequestOutput) -> None:
if self._finished:
return
self._queue.put_nowait(item)
def finish(self) -> None:
self._queue.put_nowait(StopIteration)
self._finished = True
@property
def finished(self) -> bool:
return self._finished
def __aiter__(self):
return self
async def __anext__(self) -> RequestOutput:
result = await self._queue.get()
if result is StopIteration:
raise StopAsyncIteration
elif isinstance(result, Exception):
raise result
return result
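A small usage sketch for the stream class above (assuming it is importable from this module); the payloads are plain strings rather than RequestOutput objects, purely for illustration:

import asyncio

async def main():
    stream = AsyncStream("demo-request")
    stream.put("partial output 1")
    stream.put("partial output 2")
    stream.finish()
    # Iteration ends when the StopIteration sentinel queued by finish()
    # is reached.
    async for item in stream:
        print(item)

asyncio.run(main())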
class RequestTracker:
"""Synchronous abstraction for tracking requests."""
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
def __contains__(self, item):
return item in self._request_streams
def propagate_exception(self, exc: Exception) -> None:
"""Propagate an exception to all request streams."""
for stream in self._request_streams.values():
stream.put(exc)
def process_request_output(self,
request_output: RequestOutput,
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
request_id = request_output.request_id
self._request_streams[request_id].put(request_output)
if request_output.finished:
if verbose:
logger.info(f"Finished request {request_id}.")
self.abort_request(request_id)
def add_request(self, request_id: str,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
}))
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info(f"Aborted request {request_id}.")
self._finished_requests.put_nowait(request_id)
if request_id not in self._request_streams or self._request_streams[
request_id].finished:
# The request has already finished or been aborted.
return
self._request_streams[request_id].finish()
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[dict] = []
finished_requests: Set[str] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
# The request has already been aborted.
stream.finish()
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
return new_requests, finished_requests
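To see the hand-off above in isolation, here is a toy walk-through (assuming RequestTracker and AsyncStream from this module; the request kwargs are placeholders rather than real engine arguments):

tracker = RequestTracker()
stream = tracker.add_request("req-0", prompt="Hello", sampling_params=None)
tracker.abort_request("req-1")  # ids with no live stream are recorded but otherwise ignored

new_requests, finished = tracker.get_new_and_finished_requests()
print([r["request_id"] for r in new_requests])  # ['req-0']
print(finished)                                  # {'req-1'}
print("req-0" in tracker)                        # True -- its stream is now registered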
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
async def step_async(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
(seq_group_metadata_list, scheduler_outputs,
early_return) = self._schedule()
if early_return is not None:
return early_return
# Execute the model.
output = await self._run_workers_async(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
return self._process_model_outputs(output, scheduler_outputs)
async def _run_workers_async(
self,
method: str,
*args,
get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
if self.parallel_config.worker_use_ray:
executor = partial(worker.execute_method.remote, method)
else:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.worker_use_ray:
all_outputs = await asyncio.gather(*all_outputs)
if get_all_outputs:
return all_outputs
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
assert output == other_output
return output
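The dispatch helper above either awaits Ray futures or calls the workers directly; below is a stripped-down, Ray-free sketch of the same pattern with an invented DummyWorker class:

import asyncio

class DummyWorker:
    def execute_model(self, step: int) -> int:
        # Stand-in for the real model execution; returns a fake token id.
        return step * 2

async def run_workers(workers, method: str, *args, **kwargs):
    # Without Ray the calls are plain synchronous method calls; with Ray
    # they would be worker.execute_method.remote(...) futures gathered
    # with asyncio.gather().
    outputs = [getattr(worker, method)(*args, **kwargs) for worker in workers]
    first = outputs[0]
    assert all(out == first for out in outputs), "workers must agree"
    return first

print(asyncio.run(run_workers([DummyWorker(), DummyWorker()], "execute_model", 3)))  # 6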
class AsyncLLMEngine: class AsyncLLMEngine:
@@ -33,55 +230,131 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the async frontend will be executed in a separate process as the
model workers. model workers.
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args, *kwargs: Arguments for LLMEngine. *args, *kwargs: Arguments for LLMEngine.
""" """
def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
log_requests: bool = True, *args, **kwargs) -> None: _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
def __init__(self,
worker_use_ray: bool,
engine_use_ray: bool,
*args,
log_requests: bool = True,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray self.engine_use_ray = engine_use_ray
self.log_requests = log_requests self.log_requests = log_requests
if not self.engine_use_ray: self.engine = self._init_engine(*args, **kwargs)
engine_class = LLMEngine
elif self.worker_use_ray:
engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
else:
engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
self.engine = engine_class(*args, **kwargs)
# Request id -> request output.
self.request_outputs: Dict[str, RequestOutput] = {}
# Request id -> event to notify that there is new output.
self.request_events: Dict[str, asyncio.Event] = {}
self.is_engine_running = False
self.kicking_request_id: Optional[str] = None
async def engine_step(self, kicking_request_id: Optional[str] = None): self.request_tracker: RequestTracker = RequestTracker()
self.background_loop = None
self.start_engine_loop = start_engine_loop
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and not self.background_loop.done())
def start_background_loop(self) -> None:
"""Start the background loop."""
if self.is_running:
raise RuntimeError("Background loop is already running.")
self.background_loop = asyncio.get_event_loop().create_task(
self.run_engine_loop())
self.background_loop.add_done_callback(
partial(_raise_exception_on_finish,
request_tracker=self.request_tracker))
def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
if not self.engine_use_ray:
engine_class = self._engine_class
elif self.worker_use_ray:
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
else:
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
return engine_class(*args, **kwargs)
async def engine_step(self):
"""Kick the engine to process the waiting requests.""" """Kick the engine to process the waiting requests."""
self.is_engine_running = True
self.kicking_request_id = kicking_request_id new_requests, finished_requests = (
self.request_tracker.get_new_and_finished_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
else:
self.engine.add_request(**new_request)
if finished_requests:
await self._engine_abort(finished_requests)
if self.engine_use_ray: if self.engine_use_ray:
request_outputs = await self.engine.step.remote() request_outputs = await self.engine.step.remote()
else: else:
# Yield to the event loop to allow other coroutines to run request_outputs = await self.engine.step_async()
# while is_engine_running is True. This let the engine to add new
# requests into the queue.
await asyncio.sleep(0)
request_outputs = self.engine.step()
self.is_engine_running = False
self.kicking_request_id = None
# Notify the waiting coroutines that there are new outputs ready. # Put the outputs into the corresponding streams.
for request_output in request_outputs: for request_output in request_outputs:
request_id = request_output.request_id self.request_tracker.process_request_output(
self.request_outputs[request_id] = request_output request_output, verbose=self.log_requests)
self.request_events[request_id].set()
async def generate( async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids)
else:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
while True:
await self.engine_step()
await asyncio.sleep(0)
async def add_request(
self, self,
request_id: str,
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, prompt_token_ids: Optional[List[int]] = None,
prompt_token_ids: Optional[List[int]] = None arrival_time: Optional[float] = None,
) -> RequestOutput: ) -> AsyncStream:
if self.log_requests:
logger.info(f"Received request {request_id}: "
f"prompt: {prompt!r}, "
f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.")
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
else:
raise AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
stream = self.request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
return stream
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
"""Generate outputs for a request. """Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the Generate outputs for a request. This method is a coroutine. It adds the
@@ -103,69 +376,19 @@ class AsyncLLMEngine:
# Preprocess the request. # Preprocess the request.
arrival_time = time.time() arrival_time = time.time()
# Create an event to notify us that there is new output from the try:
# vLLM engine. stream = await self.add_request(request_id,
request_event = asyncio.Event() prompt,
self.request_events[request_id] = request_event sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
if self.log_requests: async for request_output in stream:
logger.info(f"Received request {request_id}: " yield request_output
f"prompt: {prompt!r}, " except Exception as e:
f"sampling params: {sampling_params}, " # If there is an exception, abort the request.
f"prompt token ids: {prompt_token_ids}.") self._abort(request_id)
raise e
# Add the request into the vLLM engine's waiting queue.
if self.engine_use_ray:
await self.engine.add_request.remote(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
self.engine.add_request(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
# The vLLM engine does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the engine to process the requests.
while True:
if request_id not in self.request_events:
# The request has been aborted.
return
# Kick the engine if the engine is not running.
if not self.is_engine_running:
await self.engine_step(request_id)
# Wait for new output. The group_event will be set in engine_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
try:
await asyncio.wait_for(request_event.wait(),
timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
except asyncio.TimeoutError:
continue
# Reset the event to wait for the next output.
request_event.clear()
# Decode and return new outputs.
request_output = self.request_outputs[request_id]
yield request_output
# Once finished, release the resources of the sequence group.
if request_output.finished():
if self.log_requests:
logger.info(f"Finished request {request_id}.")
del self.request_outputs[request_id]
del self.request_events[request_id]
# Kick the engine if the engine is not running. This is to
# prevent that there are still requests in engine's waiting
# queue to be executed.
if not self.is_engine_running:
await self.engine_step()
break
async def abort(self, request_id: str) -> None: async def abort(self, request_id: str) -> None:
"""Abort a request. """Abort a request.
@@ -176,43 +399,52 @@ class AsyncLLMEngine:
Args: Args:
request_id: The unique id of the request. request_id: The unique id of the request.
""" """
if request_id not in self.request_events: if not self.is_running:
# The request has already finished or been aborted. raise AsyncEngineDeadError(
return "Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if self.log_requests: return self._abort(request_id)
logger.info(f"Aborted request {request_id}.")
def _abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
self.request_tracker.abort_request(request_id,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.abort_request.remote(request_id) return await self.engine.get_model_config.remote()
else: else:
self.engine.abort_request(request_id) return self.engine.get_model_config()
if request_id in self.request_events:
del self.request_events[request_id]
if request_id in self.request_outputs:
del self.request_outputs[request_id]
# To prevent deadlock when a request is aborted while the engine is
# running.
if self.kicking_request_id == request_id:
self.is_engine_running = False
self.kicking_request_id = None
@classmethod @classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine": def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2] parallel_config = engine_configs[2]
# Initialize the cluster. # Initialize the cluster.
distributed_init_method, devices = initialize_cluster( distributed_init_method, placement_group = initialize_cluster(
parallel_config, engine_args.engine_use_ray) parallel_config, engine_args.engine_use_ray)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls(engine_args.worker_use_ray, engine = cls(engine_args.worker_use_ray,
engine_args.engine_use_ray, engine_args.engine_use_ray,
not engine_args.disable_log_requests,
*engine_configs, *engine_configs,
distributed_init_method, devices, distributed_init_method,
log_stats=not engine_args.disable_log_stats) placement_group,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop)
return engine return engine
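Putting the pieces together, a hedged end-to-end sketch of the streaming API at this revision (assuming the usual vllm module paths and a CUDA-capable install; the model name, prompt, and sampling values are placeholders):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(temperature=0.8, max_tokens=32)
    # generate() is an async generator; with start_engine_loop=True the
    # background loop is started on the first call.
    async for request_output in engine.generate("Hello, my name is",
                                                params,
                                                request_id="demo-0"):
        if request_output.finished:
            print(request_output.outputs[0].text)

asyncio.run(main())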
View File
@@ -1,21 +1,34 @@
import copy
import time import time
from typing import Any, List, Optional from functools import partial
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter from vllm.utils import Counter
from vllm.worker.worker import Worker
if ray:
from ray.air.util.torch_dist import init_torch_dist_process_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__) logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
@@ -53,19 +66,20 @@ class LLMEngine:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
distributed_init_method: str, distributed_init_method: str,
stage_devices: List[List[DeviceID]], placement_group: Optional["PlacementGroup"],
log_stats: bool, log_stats: bool,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine with config: " "Initializing an LLM engine with config: "
f"model={model_config.model!r}, " f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, " f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, "
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, " f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})" f"seed={model_config.seed})")
)
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
@@ -75,34 +89,91 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
self._verify_args() self._verify_args()
self.tokenizer = get_tokenizer(model_config.model) self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
self.seq_counter = Counter() self.seq_counter = Counter()
# Create the parallel GPU workers. # Create the parallel GPU workers.
self.workers: List[Worker] = [] if self.parallel_config.worker_use_ray:
assert len(stage_devices) == 1, "Only support one stage for now." self._init_workers_ray(placement_group)
for rank, node_resource, _ in stage_devices[0]: else:
worker_cls = Worker self._init_workers(distributed_init_method)
if self.parallel_config.worker_use_ray:
worker_cls = ray.remote(
num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-5},
)(worker_cls).remote
worker = worker_cls(
model_config,
parallel_config,
scheduler_config,
rank,
distributed_init_method,
)
self.workers.append(worker)
# Profile the memory usage and initialize the cache. # Profile the memory usage and initialize the cache.
self._init_cache() self._init_cache()
# Create the scheduler. # Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats) self.scheduler = Scheduler(scheduler_config, cache_config)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def _init_workers(self, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
self.workers: List[Worker] = []
worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
0,
distributed_init_method,
)
self.workers.append(worker)
self._run_workers(
"init_model",
get_all_outputs=True,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
self.workers: List[Worker] = []
for bundle in placement_group.bundle_specs:
if not bundle.get("GPU", 0):
continue
worker = ray.remote(
num_cpus=0,
num_gpus=1,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True),
**ray_remote_kwargs,
)(RayWorker).remote(self.model_config.trust_remote_code)
self.workers.append(worker)
# Initialize torch distributed process group for the workers.
init_torch_dist_process_group(self.workers, backend="nccl")
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
self._run_workers("init_worker",
get_all_outputs=True,
worker_init_fn=lambda: Worker(
model_config,
parallel_config,
scheduler_config,
None,
None,
))
self._run_workers(
"init_model",
get_all_outputs=True,
)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -125,8 +196,14 @@ class LLMEngine:
num_gpu_blocks = min(b[0] for b in num_blocks) num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log. # FIXME(woosuk): Change to debug log.
logger.info(f'# GPU blocks: {num_gpu_blocks}, ' logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f'# CPU blocks: {num_cpu_blocks}') f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
@@ -140,9 +217,12 @@ class LLMEngine:
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2] parallel_config = engine_configs[2]
# Initialize the cluster. # Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config) distributed_init_method, placement_group = initialize_cluster(
parallel_config)
# Create the LLM engine. # Create the LLM engine.
engine = cls(*engine_configs, distributed_init_method, devices, engine = cls(*engine_configs,
distributed_init_method,
placement_group,
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats)
return engine return engine
@@ -178,27 +258,28 @@ class LLMEngine:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seqs: List[Sequence] = [] seq_id = next(self.seq_counter)
for _ in range(sampling_params.best_of): seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seqs.append(seq)
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, seqs, sampling_params, seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time) arrival_time)
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group) self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: str) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request with the given ID. """Aborts a request(s) with the given ID.
Args: Args:
request_id: The ID of the request to abort. request_id: The ID(s) of the request to abort.
""" """
self.scheduler.abort_seq_group(request_id) self.scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests.""" """Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups() return self.scheduler.get_num_unfinished_seq_groups()
@@ -207,6 +288,251 @@ class LLMEngine:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return self.scheduler.has_unfinished_seqs()
def _schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
Optional[List[RequestOutput]]]:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if scheduler_outputs.is_empty():
return seq_group_metadata_list, scheduler_outputs, [
RequestOutput.from_seq_group(seq_group)
for seq_group in scheduler_outputs.ignored_seq_groups
]
return seq_group_metadata_list, scheduler_outputs, None
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
return current_worst_score >= highest_attainable_score
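For reference, a hedged sketch of the score this early-stopping check compares; the exact formula lives in Sequence.get_beam_search_score and is assumed here to be a length-normalized cumulative log probability.

def beam_search_score(cumulative_logprob: float, seq_len: int,
                      length_penalty: float) -> float:
    # Larger is better; length_penalty > 0 favors longer sequences,
    # length_penalty < 0 favors shorter ones.
    return cumulative_logprob / (seq_len ** length_penalty)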
def _process_sequence_group_samples(
self, seq_group: SequenceGroup,
samples: List[SequenceOutputs]) -> None:
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutputs] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
self._decode_sequence(seq)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need to remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params, best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _process_model_outputs(
self, output: SamplerOutput,
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, samples in zip(scheduled_seq_groups, output):
self._process_sequence_group_samples(seq_group, samples)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in (scheduled_seq_groups +
scheduler_outputs.ignored_seq_groups):
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
if self.log_stats:
# Log the system stats.
self._log_system_stats(scheduler_outputs.prompt_run,
scheduler_outputs.num_batched_tokens)
return request_outputs
def step(self) -> List[RequestOutput]: def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
@@ -216,10 +542,10 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() (seq_group_metadata_list, scheduler_outputs,
if (not seq_group_metadata_list) and scheduler_outputs.is_empty(): early_return) = self._schedule()
# Nothing to do. if early_return is not None:
return [] return early_return
# Execute the model. # Execute the model.
output = self._run_workers( output = self._run_workers(
@@ -229,80 +555,125 @@ class LLMEngine:
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) )
# Update the scheduler with the model outputs.
seq_groups = self.scheduler.update(output)
# Decode the sequences. return self._process_model_outputs(output, scheduler_outputs)
self._decode_sequences(seq_groups)
# Stop the sequences that meet the stopping criteria.
self._stop_sequences(seq_groups)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs. def _log_system_stats(
request_outputs: List[RequestOutput] = [] self,
for seq_group in seq_groups: prompt_run: bool,
request_output = RequestOutput.from_seq_group(seq_group) num_batched_tokens: int,
request_outputs.append(request_output) ) -> None:
return request_outputs now = time.time()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None: elapsed_time = now - self.last_logging_time
"""Decodes the sequence outputs.""" if elapsed_time < _LOGGING_INTERVAL_SEC:
for seq_group in seq_groups: return
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
self.tokenizer,
seq.output_tokens,
seq.get_last_token_id(),
skip_special_tokens=True,
)
seq.output_tokens.append(new_token)
seq.output_text = new_output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: # Discard the old stats.
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
if now - t < _LOGGING_INTERVAL_SEC]
self.num_generation_tokens = [(t, n)
for t, n in self.num_generation_tokens
if now - t < _LOGGING_INTERVAL_SEC]
if len(self.num_prompt_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
window = now - self.num_prompt_tokens[0][0]
avg_prompt_throughput = total_num_tokens / window
else:
avg_prompt_throughput = 0.0
if len(self.num_generation_tokens) > 1:
total_num_tokens = sum(n
for _, n in self.num_generation_tokens[:-1])
window = now - self.num_generation_tokens[0][0]
avg_generation_throughput = total_num_tokens / window
else:
avg_generation_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = (
self.scheduler.block_manager.get_num_free_gpu_blocks())
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = (
self.scheduler.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info("Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Swapped: {len(self.scheduler.swapped)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
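A tiny numeric illustration of the windowed average computed above, mirroring the generation-token branch (values are made up):

now = 10.0
num_generation_tokens = [(5.2, 32), (7.1, 40), (9.8, 48)]   # (timestamp, tokens)
total = sum(n for _, n in num_generation_tokens[:-1])       # last batch excluded
window = now - num_generation_tokens[0][0]                  # 4.8 s
print(f"{total / window:.1f} tokens/s")                     # 72 / 4.8 -> 15.0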
def _decode_sequence(self, seq: Sequence) -> None:
"""Decodes the new token for a sequence."""
new_token, new_output_text = detokenize_incrementally(
self.tokenizer,
seq.output_tokens,
seq.get_last_token_id(),
skip_special_tokens=True,
)
if new_token is not None:
seq.output_tokens.append(new_token)
seq.output_text = new_output_text
def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences.""" """Stop the finished sequences."""
for seq_group in seq_groups: for stop_str in sampling_params.stop:
sampling_params = seq_group.sampling_params if seq.output_text.endswith(stop_str):
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): # Truncate the output text so that the stop string is
# Check if the sequence has generated a stop string. # not included in the output.
stopped = False seq.output_text = seq.output_text[:-len(stop_str)]
for stop_str in sampling_params.stop: seq.status = SequenceStatus.FINISHED_STOPPED
if seq.output_text.endswith(stop_str): return
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(seq,
SequenceStatus.FINISHED_STOPPED)
stopped = True
break
if stopped:
continue
# Check if the sequence has reached max_tokens. # Check if the sequence has reached max_model_len.
if seq.get_output_len() == sampling_params.max_tokens: if seq.get_len() > self.scheduler_config.max_model_len:
self.scheduler.free_seq( seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
seq, SequenceStatus.FINISHED_LENGTH_CAPPED) return
continue
# Check if the sequence has generated the EOS token. # Check if the sequence has reached max_tokens.
if not sampling_params.ignore_eos: if seq.get_output_len() == sampling_params.max_tokens:
if seq.get_last_token_id() == self.tokenizer.eos_token_id: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
self.scheduler.free_seq(seq, return
SequenceStatus.FINISHED_STOPPED)
continue # Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
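The stop-string branch of _check_stop can be illustrated in isolation (a toy snippet, not engine code): the matched stop string is stripped from the visible text before the sequence is marked finished.

output_text = "Hello world<|end|>"
stop_str = "<|end|>"
if output_text.endswith(stop_str):
    # Truncate so the stop string is not included in the returned output.
    output_text = output_text[:-len(stop_str)]
assert output_text == "Hello world"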
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
get_all_outputs: bool = False,
*args, *args,
get_all_outputs: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers."""
all_outputs = [] all_outputs = []
for worker in self.workers: for worker in self.workers:
executor = getattr(worker, method)
if self.parallel_config.worker_use_ray: if self.parallel_config.worker_use_ray:
executor = executor.remote executor = partial(worker.execute_method.remote, method)
else:
executor = getattr(worker, method)
output = executor(*args, **kwargs) output = executor(*args, **kwargs)
all_outputs.append(output) all_outputs.append(output)

View File

@@ -1,21 +1,53 @@
import random import socket
from typing import List, Optional, Tuple from typing import Optional, Tuple, TYPE_CHECKING
try:
import ray
except ImportError:
ray = None
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id try:
import ray
from ray.air.util.torch_dist import TorchDistributedWorker
class RayWorker(TorchDistributedWorker):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self, init_cached_hf_modules=False) -> None:
if init_cached_hf_modules:
# pylint: disable=import-outside-toplevel
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
def __getattr__(self, name):
return getattr(self.worker, name)
def execute_method(self, method, *args, **kwargs):
executor = getattr(self, method)
return executor(*args, **kwargs)
except ImportError:
ray = None
TorchDistributedWorker = None
RayWorker = None # pylint: disable=invalid-name
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def initialize_cluster( def initialize_cluster(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
engine_use_ray: bool = False, engine_use_ray: bool = False,
ray_address: Optional[str] = None, ray_address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]: ) -> Tuple[str, Optional["PlacementGroup"]]:
"""Initialize the distributed cluster probably with Ray. """Initialize the distributed cluster probably with Ray.
Args: Args:
@@ -37,71 +69,46 @@ def initialize_cluster(
"Ray is not installed. Please install Ray to use distributed " "Ray is not installed. Please install Ray to use distributed "
"serving.") "serving.")
# Connect to a ray cluster. # Connect to a ray cluster.
ray.init(address=ray_address) ray.init(address=ray_address, ignore_reinit_error=True)
if not parallel_config.worker_use_ray: if not parallel_config.worker_use_ray:
# Initialize cluster locally. # Initialize cluster locally.
port = random.randint(10000, 20000) port = get_open_port()
# We need to setup the distributed init method to make sure # We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly. # the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}" distributed_init_method = f"tcp://localhost:{port}"
all_stage_devices = [[(0, None, 0)]] return distributed_init_method, None
return distributed_init_method, all_stage_devices
# Assume we have a uniform cluster that each node has the same number of current_placement_group = ray.util.get_current_placement_group()
# GPUs for now. if current_placement_group:
valid_node_resources = [] # We are in a placement group
num_devices_per_node = None bundles = current_placement_group.bundle_specs
for node in ray.nodes(): # Verify that we can use the placement group.
if (not node['Alive']) or node['Resources']['GPU'] <= 0: gpu_bundles = 0
continue for bundle in bundles:
if num_devices_per_node is None: bundle_gpus = bundle.get("GPU", 0)
num_devices_per_node = node['Resources']['GPU'] if bundle_gpus > 1:
else: raise ValueError(
assert num_devices_per_node == node['Resources']['GPU'], ( "Placement group bundle cannot have more than 1 GPU.")
"The number of GPUs per node is not uniform.") if bundle_gpus:
for key in node['Resources']: gpu_bundles += 1
if key.startswith('node:'): if parallel_config.world_size > gpu_bundles:
valid_node_resources.append(key)
# Verify the parallel config.
num_nodes = len(valid_node_resources)
if parallel_config.world_size > num_nodes * num_devices_per_node:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs.")
if parallel_config.tensor_parallel_size >= num_devices_per_node:
if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
raise ValueError( raise ValueError(
"The number of tensor parallelism is not divisible by the " "The number of required GPUs exceeds the total number of "
"number of GPUs per node.") "available GPUs in the placement group.")
else: else:
if num_devices_per_node % parallel_config.tensor_parallel_size != 0: num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
if parallel_config.world_size > num_gpus_in_cluster:
raise ValueError( raise ValueError(
"The number of GPUs per node is not divisible by the number " "The number of required GPUs exceeds the total number of "
"of tensor parallelism.") "available GPUs in the cluster.")
# Create a new placement group
current_placement_group = ray.util.placement_group([{
"GPU": 1
}] * parallel_config.world_size)
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800)
# Assign GPUs to pipeline stages. return None, current_placement_group
rank = 0
current_node_id = 0
current_device_id = 0
distributed_init_method = None
all_stage_devices = []
for _ in range(parallel_config.pipeline_parallel_size):
stage_devices = []
for _ in range(parallel_config.tensor_parallel_size):
node_resource = valid_node_resources[current_node_id]
stage_devices.append((rank, node_resource, current_device_id))
if distributed_init_method is None:
ip = node_resource.split("node:")[-1]
port = random.randint(10000, 20000)
distributed_init_method = f"tcp://{ip}:{port}"
rank += 1
current_device_id += 1
if current_device_id >= num_devices_per_node:
current_node_id += 1
current_device_id = 0
all_stage_devices.append(stage_devices)
return distributed_init_method, all_stage_devices
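A short sketch of the new return contract of initialize_cluster (the ParallelConfig constructor fields are assumptions based on this codebase): without Ray workers it hands back a local TCP init method and no placement group.

from vllm.config import ParallelConfig
from vllm.engine.ray_utils import initialize_cluster

parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=1,
                                 worker_use_ray=False)
init_method, pg = initialize_cluster(parallel_config)
# init_method is "tcp://localhost:<open port>" and pg is None; with Ray
# workers the roles flip: init_method is None and pg is a placement group
# with one GPU per bundle.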

View File

@@ -1,92 +0,0 @@
from typing import List, Tuple, Union
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.logger import init_logger
logger = init_logger(__name__)
_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
def get_tokenizer(
model_name: str,
*args,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
config = AutoConfig.from_pretrained(model_name)
if "open_llama" in model_name:
kwargs["use_fast"] = False
logger.info(
"OpenLLaMA models do not support the fast tokenizer. "
"Using the slow tokenizer instead.")
elif config.model_type == "llama" and getattr(kwargs, "use_fast", True):
# LLaMA fast tokenizer causes protobuf errors in some environments.
# However, we found that the below LLaMA fast tokenizer works well in
# most environments.
model_name = "hf-internal-testing/llama-tokenizer"
logger.info(
f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
"potential protobuf errors.")
elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
if getattr(kwargs, "use_fast", False) == True:
raise ValueError(
f"Cannot use the fast tokenizer for {config.model_type} due to "
"bugs in the fast tokenizer.")
logger.info(
f"Using the slow tokenizer for {config.model_type} due to bugs in "
"the fast tokenizer. This could potentially lead to performance "
"degradation.")
kwargs["use_fast"] = False
return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prev_output_tokens: List[str],
new_token_id: int,
skip_special_tokens: bool,
) -> Tuple[str, str]:
"""Detokenizes the new token in conjuction with the previous output tokens.
NOTE: This function does not update prev_output_tokens.
Returns:
new_token: The new token as a string.
output_text: The new output text as a string.
"""
new_token = tokenizer.convert_ids_to_tokens(
new_token_id, skip_special_tokens=skip_special_tokens)
output_tokens = prev_output_tokens + [new_token]
# Convert the tokens to a string.
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
# then we can directly use `convert_tokens_to_string`.
if not getattr(tokenizer, "added_tokens_encoder", {}):
output_text = tokenizer.convert_tokens_to_string(output_tokens)
return new_token, output_text
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts = []
current_sub_text = []
for token in output_tokens:
if skip_special_tokens and token in tokenizer.all_special_ids:
continue
if token in tokenizer.added_tokens_encoder:
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
output_text = " ".join(sub_texts)
return new_token, output_text
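Although this helper now lives in vllm.transformers_utils.tokenizer, its contract is easy to demonstrate with any HuggingFace tokenizer (GPT-2 is an arbitrary choice here and assumes transformers is installed):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
prev_tokens = tok.convert_ids_to_tokens(tok("Hello").input_ids)
new_id = tok(" world").input_ids[0]
new_token = tok.convert_ids_to_tokens(new_id)
# Joining the previous tokens with the new one yields the updated output text.
print(tok.convert_tokens_to_string(prev_tokens + [new_token]))  # "Hello world"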

View File

@@ -3,7 +3,7 @@ import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
@@ -11,9 +11,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI() app = FastAPI()
engine = None
@app.post("/generate") @app.post("/generate")
@@ -30,6 +31,7 @@ async def generate(request: Request) -> Response:
stream = request_dict.pop("stream", False) stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id) results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case # Streaming case
@@ -37,8 +39,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator: async for request_output in results_generator:
prompt = request_output.prompt prompt = request_output.prompt
text_outputs = [ text_outputs = [
prompt + output.text prompt + output.text for output in request_output.outputs
for output in request_output.outputs
] ]
ret = {"text": text_outputs} ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
@@ -63,12 +64,9 @@ async def generate(request: Request) -> Response:
assert final_output is not None assert final_output is not None
prompt = final_output.prompt prompt = final_output.prompt
text_outputs = [ text_outputs = [prompt + output.text for output in final_output.outputs]
prompt + output.text
for output in final_output.outputs
]
ret = {"text": text_outputs} ret = {"text": text_outputs}
return Response(content=json.dumps(ret)) return JSONResponse(ret)
if __name__ == "__main__": if __name__ == "__main__":
@@ -81,5 +79,8 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug", uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE) timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
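A hedged client-side example for the demo server above; the field names ("prompt", optional "stream", remaining keys forwarded to SamplingParams) are inferred from the handler, and host/port are the argparse defaults.

import requests

resp = requests.post("http://localhost:8000/generate",
                     json={"prompt": "Hello, my name is",
                           "max_tokens": 16,
                           "stream": False})
print(resp.json()["text"])  # list of prompt + completion strings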

View File

@@ -25,6 +25,11 @@ class LLM:
Args: Args:
model: The name or path of a HuggingFace Transformers model. model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism. execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently, dtype: The data type for the model weights and activations. Currently,
@@ -38,6 +43,9 @@ class LLM:
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: str = "auto", dtype: str = "auto",
seed: int = 0, seed: int = 0,
@@ -47,6 +55,9 @@ class LLM:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
engine_args = EngineArgs( engine_args = EngineArgs(
model=model, model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
seed=seed, seed=seed,
@@ -56,10 +67,15 @@ class LLM:
self.request_counter = Counter() self.request_counter = Counter()
def get_tokenizer( def get_tokenizer(
self, self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def generate( def generate(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: Optional[Union[str, List[str]]] = None,
@@ -133,10 +149,14 @@ class LLM:
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
for output in step_outputs: for output in step_outputs:
if output.finished(): if output.finished:
outputs.append(output) outputs.append(output)
if use_tqdm: if use_tqdm:
pbar.update(1) pbar.update(1)
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs return outputs
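The expanded constructor can be exercised end to end roughly as follows (a sketch; the model name is an arbitrary placeholder):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          tokenizer=None,            # default: the model's own tokenizer
          tokenizer_mode="auto",     # "auto" prefers the fast tokenizer
          trust_remote_code=False)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)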

View File

@@ -1,47 +1,61 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse import argparse
from http import HTTPStatus import asyncio
import json import json
import time import time
from typing import AsyncGenerator, Dict, List, Optional from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi import fastapi
import uvicorn
from fastapi import BackgroundTasks, Request from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn from packaging import version
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.tokenizer_utils import get_tokenizer
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo) LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True
except ImportError:
_fastchat_available = False
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__) logger = init_logger(__name__)
served_model = None served_model = None
app = fastapi.FastAPI() app = fastapi.FastAPI()
engine = None
def create_error_response(status_code: HTTPStatus, def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse: message: str) -> JSONResponse:
return JSONResponse( return JSONResponse(ErrorResponse(message=message,
ErrorResponse(message=message, type="invalid_request_error").dict(), type="invalid_request_error").dict(),
status_code=status_code.value status_code=status_code.value)
)
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
@@ -55,11 +69,88 @@ async def check_model(request) -> Optional[JSONResponse]:
return ret return ret
async def get_gen_prompt(request) -> str:
if not _fastchat_available:
raise ModuleNotFoundError(
"fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model)
conv = Conversation(
name=conv.name,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
if prompt_ids is not None:
input_ids = prompt_ids
else:
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return input_ids, None
@app.get("/v1/models") @app.get("/v1/models")
async def show_available_models(): async def show_available_models():
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model, model_cards = [
permission=[ModelPermission()])] ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards) return ModelList(data=model_cards)
@@ -76,17 +167,189 @@ def create_logprobs(token_ids: List[int],
if len(logprobs.text_offset) == 0: if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset) logprobs.text_offset.append(initial_text_offset)
else: else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token) last_token_len = len(token)
logprobs.top_logprobs.append( logprobs.top_logprobs.append({
{tokenizer.convert_ids_to_tokens(i): p tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()}) for i, p in id_logprob.items()
})
return logprobs return logprobs
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
logger.info(f"Received chat completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
response_json = create_stream_response_json(
index=i,
text="",
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
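A hedged client example for the new chat route; the model field must match --served-model-name (or the HuggingFace name), and the payload follows the OpenAI chat schema handled above.

import requests

payload = {
    "model": "facebook/opt-125m",
    "messages": [{"role": "user", "content": "Say hello"}],
    "max_tokens": 16,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])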
@app.post("/v1/completions") @app.post("/v1/completions")
async def create_completion(raw_request: Request): async def create_completion(request: CompletionRequest, raw_request: Request):
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create See https://platform.openai.com/docs/api-reference/completions/create
@@ -99,7 +362,6 @@ async def create_completion(raw_request: Request):
suffix) suffix)
- logit_bias (to be supported by vLLM engine) - logit_bias (to be supported by vLLM engine)
""" """
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}") logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request) error_check_ret = await check_model(request)
@@ -115,7 +377,7 @@ async def create_completion(raw_request: Request):
if request.suffix is not None: if request.suffix is not None:
# The language models we currently support do not support suffix. # The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported") "suffix is not currently supported")
if request.logit_bias is not None: if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine. # TODO: support logit_bias in vLLM engine.
@@ -124,7 +386,34 @@ async def create_completion(raw_request: Request):
model_name = request.model model_name = request.model
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
prompt = request.prompt
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handles multiple prompt case in list[list[int]]
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
_, error_check_ret = await check_length(request, prompt_ids=prompt)
else:
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.time())
try:
sampling_params = SamplingParams(
@@ -144,22 +433,30 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
if use_token_ids:
result_generator = engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
@@ -200,7 +497,8 @@ async def create_completion(raw_request: Request):
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
response_json = create_stream_response_json(
index=i,
text="",
@@ -208,7 +506,7 @@ async def create_completion(raw_request: Request):
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
@@ -244,8 +542,8 @@ async def create_completion(raw_request: Request):
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@@ -263,9 +561,11 @@ async def create_completion(raw_request: Request):
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
@@ -274,26 +574,34 @@ async def create_completion(raw_request: Request):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server." description="vLLM OpenAI-Compatible RESTful API server.")
) parser.add_argument("--host",
parser.add_argument("--host", type=str, default="localhost", help="host name") type=str,
default="localhost",
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument( parser.add_argument("--allow-credentials",
"--allow-credentials", action="store_true", help="allow credentials" action="store_true",
) help="allow credentials")
parser.add_argument( parser.add_argument("--allowed-origins",
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins" type=json.loads,
) default=["*"],
parser.add_argument( help="allowed origins")
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods" parser.add_argument("--allowed-methods",
) type=json.loads,
parser.add_argument( default=["*"],
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers" help="allowed methods")
) parser.add_argument("--allowed-headers",
parser.add_argument("--served-model-name", type=str, default=None, type=json.loads,
help="The model name used in the API. If not specified, " default=["*"],
"the model name will be the same as the " help="allowed headers")
"huggingface name.") parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
@@ -307,13 +615,23 @@ if __name__ == "__main__":
logger.info(f"args: {args}") logger.info(f"args: {args}")
served_model = args.served_model_name or args.model if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len()
# A separate tokenizer to map token IDs to strings. # A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model) tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode,
trust_remote_code=engine_args.trust_remote_code)
uvicorn.run(app, host=args.host, port=args.port, log_level="info", uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE) timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

View File

@@ -1,4 +1,5 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
@@ -53,21 +54,28 @@ class UsageInfo(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
class CompletionRequest(BaseModel):
model: str
# a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
@@ -92,7 +100,8 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):
@@ -124,3 +133,42 @@ class CompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
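A minimal sketch of how these pydantic request models are used; the import path and all field values below are assumptions for illustration, not part of the diff.

# Sketch only: the module path is assumed, the values are invented.
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest)

# prompt may be a string or pre-tokenized IDs, per the Union type above.
completion_req = CompletionRequest(model="my-model", prompt="Hello, my name is")
tokenized_req = CompletionRequest(model="my-model", prompt=[1, 15043, 29892])
# Chat requests accept the extra vLLM parameters (best_of, top_k, use_beam_search, ...).
chat_req = ChatCompletionRequest(model="my-model",
                                 messages=[{"role": "user", "content": "Say hi"}],
                                 top_k=40)
print(completion_req.max_tokens, tokenized_req.prompt, chat_req.top_k)  # 16 [1, 15043, 29892] 40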

View File

@@ -1,9 +1,9 @@
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import logging
import sys
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"

View File

@@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",

View File

@@ -1,18 +1,29 @@
from typing import Dict, List, Tuple
import torch
from xformers.ops import AttentionBias
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData
class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
@@ -27,7 +38,6 @@ class InputMetadata:
self.max_context_len = max_context_len
self.block_tables = block_tables
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
@@ -39,6 +49,9 @@ class InputMetadata:
assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens
# Set during the execution of the first attention op.
self.attn_bias: List[AttentionBias] = []
def __repr__(self) -> str:
# Print only useful metadata.
return (f'InputMetadata('
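For intuition, slot_mapping and block_tables encode where each token's KV vectors live in the paged cache. A toy sketch of the addressing arithmetic (the block size and block table contents are invented for illustration):

# Toy illustration of paged KV addressing; block_size and the block table are
# made up, not taken from the diff.
block_size = 16
block_table = [3, 7, 12]          # logical block i -> physical block block_table[i]

def slot_for_position(pos: int) -> int:
    """Physical slot where the KV for token position `pos` is written."""
    logical_block, offset = divmod(pos, block_size)
    return block_table[logical_block] * block_size + offset

print(slot_for_position(0))   # 48  (physical block 3, offset 0)
print(slot_for_position(20))  # 116 (physical block 7, offset 4)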

View File

@@ -4,10 +4,50 @@ import torch.nn as nn
from vllm import activation_ops
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
Shapes:
x: (num_tokens, 2 * d)
return: (num_tokens, d)
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out
class NewGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1]
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.gelu_new(out, x)
return out
class FastGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1]
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.gelu_fast(out, x)
return out
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
}
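For reference, the SwiGLU computation that SiluAndMul offloads to the silu_and_mul kernel can be written in plain PyTorch; this is only an equivalence sketch, not code from the diff.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -> (num_tokens, d), matching the SiluAndMul docstring.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 8)
print(silu_and_mul_reference(x).shape)  # torch.Size([4, 4])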
@@ -18,23 +58,3 @@ def get_act_fn(act_fn: str) -> nn.Module:
if act_fn in _ACTIVATION_REGISTRY:
return _ACTIVATION_REGISTRY[act_fn]
raise ValueError(f"Activation function {act_fn!r} is not supported.")
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
"""
def __init__(self):
super().__init__()
def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out

View File

@@ -1,28 +1,39 @@
"""Multi-head attention.""" """Multi-head attention."""
from typing import Optional from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm import attention_ops from vllm import attention_ops
from vllm import cache_ops from vllm import cache_ops
from vllm import pos_encoding_ops from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128] _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
class PagedAttention(nn.Module):
# pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The
input 1D tensors can either contain prompt tokens or generation tokens, in
addition to paddings.
If the input tensors contain prompt tokens, the layout is as follows:
|<---------------------- num_valid_tokens ---------------------->|
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
Otherwise, the layout is as follows:
|<------------------ num_valid_tokens ------------------->|
|<------- num_generation_tokens (M) ------->|
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while the generation tokens always
have length 1. The paddings are appended to make the input length a multiple
@@ -41,34 +52,73 @@ class PagedAttention(nn.Module):
5. Output a flattened 1D tensor.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_mapping = torch.repeat_interleave(
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
self.num_queries_per_kv)
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def set_attn_bias(
self,
input_metadata: InputMetadata,
dtype: torch.dtype,
) -> None:
del dtype # Unused.
if input_metadata.attn_bias:
# Already set by a previous layer.
return
prompt_lens = input_metadata.prompt_lens
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
input_metadata.attn_bias.append(attn_bias)
def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Normal attention for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads.
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=1)
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=input_metadata.attn_bias[0],
p=0.0,
scale=self.scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output.copy_(out.squeeze(0))
@@ -76,42 +126,72 @@ class PagedAttention(nn.Module):
def single_query_cached_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.head_mapping,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
None,  # alibi_slopes
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
"""PagedAttention forward pass.
NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [num_tokens, 3 * num_heads * head_size].
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# Pre-allocate the output tensor.
output = torch.empty_like(query)
@@ -119,12 +199,15 @@ class PagedAttention(nn.Module):
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata, dtype=query.dtype)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata,
)
# Wait until the cache op is done.
@@ -136,7 +219,7 @@ class PagedAttention(nn.Module):
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
@@ -147,17 +230,16 @@ class PagedAttention(nn.Module):
)
if input_metadata.num_generation_tokens > 0:
# Decoding run.
assert input_metadata.num_prompt_tokens == 0
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens.")
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens], key_cache,
value_cache, input_metadata)
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
@@ -165,7 +247,7 @@ class PagedAttention(nn.Module):
class PagedAttentionWithRoPE(PagedAttention):
"""PagedAttention with rotary embedding."""
def __init__(
self,
@@ -175,19 +257,24 @@ class PagedAttentionWithRoPE(PagedAttention):
rotary_dim: int,
max_position: int = 8192,
base: int = 10000,
num_kv_heads: Optional[int] = None,
is_neox_style: bool = True,
) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
self.is_neox_style = is_neox_style
# Create the cos and sin cache.
inv_freq = 1.0 / (base**(torch.arange(
0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
t = torch.arange(max_position, dtype=torch.float, device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model.
# TODO(woosuk): Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
@@ -195,23 +282,42 @@ class PagedAttentionWithRoPE(PagedAttention):
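A standalone sketch of the cos/sin cache construction above, using a tiny made-up rotary_dim and max_position; it mirrors the rule inv_freq[j] = 1 / base**(2j / rotary_dim), with position p using angle p * inv_freq[j].

import torch

# Illustration only; the sizes here are invented, and the real code builds
# this on the GPU with the model dtype.
base, rotary_dim, max_position = 10000, 8, 16
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float)
freqs = torch.einsum("i,j->ij", t, inv_freq)          # [max_position, rotary_dim // 2]
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cache.shape)  # torch.Size([16, 8]) == [max_position, rotary_dim]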
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return super().forward(
query,
@@ -222,3 +328,127 @@ class PagedAttentionWithRoPE(PagedAttention):
input_metadata,
cache_event,
)
class PagedAttentionWithALiBi(PagedAttention):
"""PagedAttention with ALiBi attention bias."""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
num_kv_heads: Optional[int] = None) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
assert len(slopes) == num_heads
slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False)
def set_attn_bias(self, input_metadata: InputMetadata,
dtype: torch.dtype) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# Note(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
bias = bias.to(self.alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (prompt_len + 7) // 8 * 8
bias = torch.empty(
1, # batch_size
self.num_heads,
prompt_len,
padded_len,
device=self.alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
input_metadata.attn_bias.append(attn_bias)
def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention with ALiBi bias for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads.
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=1)
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output
def single_query_cached_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention with ALiBi bias for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.head_mapping,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
self.alibi_slopes,
)
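A small self-contained sketch of the ALiBi bias that set_attn_bias builds above; the slopes and prompt length are invented for illustration.

import torch

# Illustration of the ALiBi bias construction used above.
prompt_len = 4
slopes = torch.tensor([0.5, 0.25])               # one slope per head (made up)
pos = torch.arange(prompt_len, dtype=torch.float32)
rel = pos[None, :] - pos[:, None]                # rel[i, j] = j - i
bias = rel[None, :, :] * slopes[:, None, None]   # [num_heads, prompt_len, prompt_len]
print(bias[0])
# The LowerTriangularMaskWithTensorBias wrapper then masks out positions j > i.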

View File

@@ -9,7 +9,9 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceOutputs
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
@@ -36,12 +38,15 @@ class Sampler(nn.Module):
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]
@@ -49,34 +54,37 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(temperatures,
dtype=logits.dtype,
device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p and top-k).
logprobs = torch.log(probs)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
@@ -92,12 +100,12 @@ def _prune_hidden_states(
start_idx += prompt_len
last_token_indicies.extend(
range(start_idx, start_idx + input_metadata.num_generation_tokens))
return hidden_states.index_select(
0, torch.tensor(last_token_indicies, device=hidden_states.device))
def _get_penalties(
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
@@ -116,9 +124,7 @@ def _get_penalties(
return presence_penalties, frequency_penalties
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
@@ -152,7 +158,7 @@ def _apply_penalties(
continue
p = presence_penalties[i]
f = frequency_penalties[i]
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
continue
indices.append(i)
@@ -168,11 +174,13 @@ def _apply_penalties(
device=logits.device)
frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
@@ -182,15 +190,13 @@ def _apply_penalties(
return logits
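Per the OpenAI parameter-details page cited above, the penalties shift each token's logit by how often that token already appeared in the output. A scalar sketch of that rule (the helper name is invented for illustration):

# Scalar illustration of the OpenAI-style penalty rule referenced above.
def penalized_logit(logit: float, count: int, presence_penalty: float,
                    frequency_penalty: float) -> float:
    # The frequency penalty scales with how often the token already appeared;
    # the presence penalty applies once if it appeared at all.
    return logit - frequency_penalty * count - presence_penalty * (count > 0)

print(penalized_logit(2.0, count=3, presence_penalty=0.5, frequency_penalty=0.1))
# 2.0 - 0.1*3 - 0.5*1 = 1.2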
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature < _SAMPLING_EPS:
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
@@ -230,30 +236,32 @@ def _get_top_p_top_k(
def _apply_top_p_top_k(
logits: torch.Tensor,
top_ps: List[float],
top_ks: List[int],
) -> torch.Tensor:
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
logits_sort[top_p_mask] = -float("inf")
# Apply top-k.
# Create a mask for the top-k elements.
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
logits_sort[top_k_mask] = -float("inf")
# Re-sort the probabilities.
logits = torch.gather(logits_sort,
dim=-1,
index=torch.argsort(logits_idx, dim=-1))
return logits
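Since truncation now happens on the logits (masking with -inf before the final softmax), here is a tiny standalone illustration of the same masking idea with made-up numbers; it is not the function above, just the pattern it implements.

import torch

# Invented values for a single row of logits.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
top_k, top_p = 3, 0.9
sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
probs = sorted_logits.softmax(dim=-1)
keep = (probs.cumsum(dim=-1) - probs) <= top_p           # top-p keep mask
keep &= torch.arange(logits.shape[-1]) < top_k            # top-k keep mask
sorted_logits[~keep] = -float("inf")
masked = torch.gather(sorted_logits, -1, sorted_idx.argsort(-1))
print(torch.softmax(masked, dim=-1))                      # mass only on the kept tokens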
def _get_topk_logprobs(
@@ -284,9 +292,15 @@ def _sample_from_prompt(
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.best_of
# Sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
_, next_token_ids = torch.topk(prob, 2 * beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling.
assert sampling_params.best_of == 1
next_token_id = torch.argmax(prob)
@@ -295,8 +309,9 @@ def _sample_from_prompt(
# Random sampling.
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(prob,
num_samples=num_seqs,
replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
@@ -314,36 +329,19 @@ def _sample_from_generation_tokens(
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(seq_logprobs,
dtype=torch.float,
device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()
seq_idx = [i // vocab_size for i in topk_ids]
parent_seq_ids = [seq_ids[i] for i in seq_idx]
next_token_ids = [i % vocab_size for i in topk_ids]
elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling.
assert len(seq_ids) == 1
next_token_id = torch.argmax(probs, dim=-1)
@@ -352,8 +350,9 @@ def _sample_from_generation_tokens(
else:
# Random sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
@@ -363,16 +362,18 @@ def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> SamplerOutput:
seq_outputs: SamplerOutput = []
# TODO(woosuk): Optimize.
idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_group_outputs: List[SequenceOutputs] = []
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == 1, "Prompt input should have only one seq."
parent_seq_id = seq_ids[0]
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
@@ -380,45 +381,47 @@ def _sample(
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(logprob,
sampling_params.logprobs)
# Build the output.
for next_token_id in next_token_ids:
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_group_outputs.append(
SequenceOutputs(parent_seq_id, next_token_id,
output_logprobs))
else:
# Generate the next tokens for generation tokens.
num_parent_seqs = len(seq_ids)
prob = probs[idx:idx + num_parent_seqs]
logprob = logprobs[idx:idx + num_parent_seqs]
idx += num_parent_seqs
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for j, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[j], sampling_params.logprobs)
# Build the output.
for parent_seq_id, next_token_id in zip(parent_seq_ids,
next_token_ids):
j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[j,
next_token_id].item()
seq_group_outputs.append(
SequenceOutputs(parent_seq_id, next_token_id,
output_logprobs))
seq_outputs.append(seq_group_outputs)
return seq_outputs

View File

@@ -1,4 +1,5 @@
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib
from typing import Type from typing import Type
import torch import torch
@@ -6,19 +7,39 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM, from vllm.model_executor.models import * # pylint: disable=wildcard-import
LlamaForCausalLM, OPTForCausalLM)
from vllm.model_executor.weight_utils import initialize_dummy_weights from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes. # TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = { _MODEL_REGISTRY = {
"AquilaModel": AquilaForCausalLM,
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BloomForCausalLM": BloomForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel, "GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM, "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"InternLMForCausalLM": InternLMForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM, "LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM, "OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM,
} }
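A toy sketch of the lookup that _get_model_architecture below performs against this registry; the registry contents and config value here are invented stand-ins.

# Toy illustration only; the real registry maps architecture names to classes.
registry = {"LlamaForCausalLM": "LlamaForCausalLM class",
            "OPTForCausalLM": "OPTForCausalLM class"}
architectures = ["LlamaForCausalLM"]          # e.g. hf_config.architectures
for arch in architectures:
    if arch in registry:
        print("resolved:", registry[arch])
        break
else:
    raise ValueError(f"Model architectures {architectures} are not supported for now.")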
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
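A quick usage sketch of the context manager above (illustrative, and it relies on the definition just shown being in scope): parameters created inside the block pick up the model dtype, and the default is restored on exit.

import torch

with _set_default_torch_dtype(torch.float16):
    layer = torch.nn.Linear(4, 4)          # parameters created in fp16
print(layer.weight.dtype)                  # torch.float16
print(torch.get_default_dtype())           # back to torch.float32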
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
@@ -26,26 +47,23 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
return _MODEL_REGISTRY[arch]
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(model_config.hf_config)
if model_config.load_format == "dummy":
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.load_format)
model = model.cuda()
return model.eval()

View File

@@ -1,12 +1,31 @@
from vllm.model_executor.models.aquila import AquilaForCausalLM
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
BaichuanForCausalLM)
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.falcon import FalconForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.internlm import InternLMForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
__all__ = [
"AquilaForCausalLM",
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BloomForCausalLM",
"FalconForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTJForCausalLM",
"GPTNeoXForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"MPTForCausalLM",
"OPTForCausalLM",
"QWenLMHeadModel",
]

View File

@@ -0,0 +1,357 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class AquilaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
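# NOTE: gate_up_proj fuses the gate and up projections into a single matmul;
# SiluAndMul splits the concatenated output in half and returns
# silu(gate) * up, i.e. the LLaMA-style SwiGLU feed-forward.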
class AquilaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
AquilaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
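# In formula form: y = weight * x / sqrt(mean(x**2, dim=-1) + eps), with the
# statistics computed in float32 and the result cast back to the input dtype.
# Worked example: x = [1, 2, 3, 4] gives mean(x**2) = 7.5, so x is scaled by
# 1/sqrt(7.5) ~ 0.365, yielding roughly [0.365, 0.730, 1.095, 1.461] before
# the learned weight is applied.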
class AquilaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class AquilaDecoderLayer(nn.Module):
def __init__(self, config: AquilaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = AquilaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_attention_heads,
)
self.mlp = AquilaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = AquilaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class AquilaModel(nn.Module):
def __init__(self, config: AquilaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
#vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class AquilaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = AquilaModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
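# For example, a 100008-token vocabulary is padded up to 100032, the next
# multiple of 64, which keeps the lm_head shards evenly sized under tensor
# parallelism.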
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_attention_heads // tp_size)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
]
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)


@@ -0,0 +1,374 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
PagedAttentionWithALiBi)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
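# Worked example: for total_num_heads = 8 the base is 2**-1, so the slopes are
# [2**-1, 2**-2, ..., 2**-8]. For 12 heads (not a power of two), the first 8
# slopes follow the 8-head schedule and the remaining 4 come from the 16-head
# schedule at odd powers, [2**-0.5, 2**-1.5, 2**-2.5, 2**-3.5], matching the
# construction in the ALiBi paper.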
class BaiChuanMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class BaiChuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
# pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear(
hidden_size,
3 * hidden_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
else:
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
if self.postion_embedding == "ALIBI":
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
else:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class BaiChuanModel(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module):
def __init__(self, config, position_embedding: str):
super().__init__()
self.config = config
self.model = BaiChuanModel(config, position_embedding)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = []
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
def __init__(self, config):
super().__init__(config, "ALIBI")
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
def __init__(self, config):
super().__init__(config, "ROPE")


@@ -0,0 +1,324 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The CacheFlow team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BloomAttention(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.total_num_heads = config.n_head
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.dense(attn_output)
return output
class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
gather_output=False,
perform_initialization=False)
self.act = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x)
x = self.act(x)
x, _ = self.dense_4h_to_h(x)
return x
class BloomBlock(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attention_output = self.self_attention(
position_ids=position_ids,
hidden_states=layernorm_output,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output) + residual
return output
class BloomModel(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.embed_dim = config.hidden_size
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
self.word_embeddings_layernorm = nn.LayerNorm(
self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList(
[BloomBlock(config) for _ in range(config.num_hidden_layers)])
# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class BloomForCausalLM(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.config = config
self.transformer = BloomModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.word_embeddings.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to
# load lm_head.weight in parallel.
self._column_parallel_weights.append(name)
# If lm_head is provided, use it instead.
param = self.lm_head_weight
else:
if not name.startswith("transformer."):
name = "transformer." + name
param = state_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): BLOOM's fused QKV has the shape of
# [num_heads * 3 * head_size, hidden_size], while the
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
start = shard_size * tp_rank
end = shard_size * (tp_rank + 1)
loaded_weight = loaded_weight[start:end]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
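The NOTE in BloomForCausalLM.load_weights above describes re-interleaving the fused QKV weight from HF's per-head layout to the per-projection layout the sharded attention expects. A self-contained toy sketch of that re-layout (sizes are invented for illustration):

import torch

# Assumed toy sizes: 2 heads, head_size 1, hidden_size 4. HF stores rows as
# [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v]; the target layout is
# [h0_q, h1_q, h0_k, h1_k, h0_v, h1_v].
num_heads, head_size, hidden_size = 2, 1, 4
w = torch.arange(num_heads * 3 * head_size * hidden_size,
                 dtype=torch.float32).view(-1, hidden_size)
w = w.view(-1, 3, head_size, hidden_size)  # [num_heads, 3, head_size, hidden]
w = w.transpose(0, 1)                      # [3, num_heads, head_size, hidden]
w = w.reshape(-1, hidden_size)             # [3 * num_heads * head_size, hidden]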


@@ -0,0 +1,498 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention,
PagedAttentionWithALiBi,
PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = x @ self.weight.T
if self.bias is None:
return hidden_states
return hidden_states + self.bias
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(1,
1 + 2 * num_remaining_heads,
2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads
assert self.total_num_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
elif self.multi_query:
self.total_num_kv_heads = 1
self.num_kv_heads = 1
self.query = ColumnParallelLinear(
self.hidden_size,
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
2 * self.head_dim,
bias=config.bias)
else:
self.total_num_kv_heads = self.total_num_heads
self.num_kv_heads = self.num_heads
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
self.use_alibi = config.alibi
assert not (self.use_rotary and self.use_alibi), (
"Rotary and alibi are mutually exclusive.")
if self.use_rotary:
# TODO(zhuohan): Pass in correct `max_position`
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.inv_norm_factor,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = PagedAttentionWithALiBi(self.num_heads,
self.head_dim,
self.inv_norm_factor,
alibi_slopes,
num_kv_heads=self.num_kv_heads)
else:
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if not self.new_decoder_architecture and self.multi_query:
q, bias = self.query(hidden_states)
if bias is not None:
q += bias
kv = self.key_value(hidden_states)
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
else:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
k_cache, v_cache = kv_cache
if self.use_rotary:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
else:
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True)
self.act = nn.GELU()
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
x, bias = self.dense_h_to_4h(x)
if bias is not None:
x += bias
x = self.act(x)
x, bias = self.dense_4h_to_h(x)
return x, bias
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config)
self.mlp = FalconMLP(config)
self.config = config
if config.new_decoder_architecture:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
):
residual = hidden_states
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
positions=positions,
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual += attention_output
mlp_layernorm_out = self.post_attention_layernorm(residual)
# MLP.
mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
if self.reduce_row_parallel_results and mlp_bias is not None:
mlp_output += mlp_bias
if not self.reduce_row_parallel_results:
# When MLP and Attention layers are parallel, we can use
# only one all-reduce operator to reduce the results from
# both MLP and Attention layers.
mlp_output += attention_output
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
if attention_bias is not None:
mlp_output += attention_bias
if mlp_bias is not None:
mlp_output += mlp_bias
output = mlp_output + residual
return output
class FalconModel(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_alibi = config.alibi
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class FalconForCausalLM(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.config = config
self.transformer = FalconModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
input_metadata,
cache_events,
)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
hidden_size = self.config.hidden_size
total_num_heads = self.config.num_attention_heads
num_heads = total_num_heads // tp_size
head_size = hidden_size // total_num_heads
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
num_kv_heads = 1
separated_q_kv = True
kv_head_start = 0
kv_head_end = 1
else:
total_num_kv_heads = total_num_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "query_key_value" in name:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight_size = loaded_weight.size()
loaded_weight = loaded_weight.view(
total_num_kv_heads, num_query_heads_per_kv_head + 2,
head_size, *loaded_weight_size[1:])
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
wk = loaded_weight[:, [-2]].reshape(-1,
*loaded_weight_size[1:])
wv = loaded_weight[:, [-1]].reshape(-1,
*loaded_weight_size[1:])
wq = wq[head_size * head_start:head_size * head_end]
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
if separated_q_kv:
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("query_key_value", "query")
kv_weight_name = name.replace("query_key_value",
"key_value")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
continue
else:
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
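FalconForCausalLM.load_weights above views the fused query_key_value weight as groups of query heads followed by one key head and one value head per group before slicing it per tensor-parallel rank. A toy sketch of that split, with made-up sizes:

import torch

# Assumed toy sizes: 2 KV heads, 2 query heads per KV head (4 query heads
# total), head_size 1, hidden_size 4.
total_num_kv_heads, num_query_heads_per_kv_head = 2, 2
head_size, hidden_size = 1, 4
rows = total_num_kv_heads * (num_query_heads_per_kv_head + 2) * head_size
w = torch.arange(rows * hidden_size, dtype=torch.float32).view(rows, hidden_size)
w = w.view(total_num_kv_heads, num_query_heads_per_kv_head + 2, head_size,
           hidden_size)
wq = w[:, :-2].reshape(-1, hidden_size)   # all query heads, grouped by KV head
wk = w[:, [-2]].reshape(-1, hidden_size)  # one key head per group
wv = w[:, [-1]].reshape(-1, hidden_size)  # one value head per group
# Each tensor can then be sliced by head range for the local rank and
# re-concatenated, as the loader above does.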


@@ -1,5 +1,6 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
@@ -20,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -30,13 +31,14 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator, from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights) convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -47,19 +49,25 @@ class GPT2Attention(nn.Module):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0 assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, self.c_attn = ColumnParallelLinear(self.hidden_size,
bias=True, gather_output=False, 3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, self.c_proj = RowParallelLinear(self.hidden_size,
bias=True, input_is_parallel=True, self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale) scale=self.scale)
def forward( def forward(
@@ -72,8 +80,8 @@ class GPT2Attention(nn.Module):
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(q, k, v, key_cache, value_cache,
q, k, v, key_cache, value_cache, input_metadata, cache_event) input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
@@ -87,11 +95,15 @@ class GPT2MLP(nn.Module):
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size, self.c_fc = ColumnParallelLinear(hidden_size,
bias=True, gather_output=False, intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size, self.c_proj = RowParallelLinear(intermediate_size,
bias=True, input_is_parallel=True, hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.act = get_act_fn(config.activation_function) self.act = get_act_fn(config.activation_function)
@@ -107,7 +119,8 @@ class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config): def __init__(self, config: GPT2Config):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config) self.attn = GPT2Attention(config)
@@ -145,9 +158,9 @@ class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config): def __init__(self, config: GPT2Config):
super().__init__() super().__init__()
self.config = config self.config = config
assert config.add_cross_attention == False assert not config.add_cross_attention
assert config.scale_attn_by_inverse_layer_idx == False assert not config.scale_attn_by_inverse_layer_idx
assert config.reorder_and_upcast_attn == False assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it # Optimization: While the vocab size of GPT-2 is 50257, we extend it
@@ -180,8 +193,8 @@ class GPT2Model(nn.Module):
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
hidden_states, kv_caches[i], input_metadata, cache_event) cache_event)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
@@ -205,34 +218,40 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer( hidden_states = self.transformer(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head_weight, hidden_states,
self.lm_head_weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"] _row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): load_format: str = "auto"):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, use_np_cache):
+ model_name_or_path, cache_dir, load_format):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
- if ".attn.bias" in name:
+ if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
- name = "transformer." + name
+ if not name.startswith("transformer."):
+ name = "transformer." + name
+ loaded_weight = convert_pyslice_to_tensor(loaded_weight)
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
@@ -245,17 +264,16 @@ class GPT2LMHeadModel(nn.Module):
param = state_dict[name]
if name == "transformer.wte.weight":
- # Consider padding in the vocab size.
- padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
- num_extra_rows = padded_vocab_size - self.config.vocab_size
- extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
- extra_rows = extra_rows.to(loaded_weight)
- loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+ load_padded_tensor_parallel_vocab(param, loaded_weight,
+ tensor_model_parallel_rank)
+ continue
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
- # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
- # When tensor parallelism is used, we shard the weights along the head dimension.
+ # GPT-2's fused QKV has the shape of
+ # [3 * num_heads * head_size, hidden_size].
+ # When tensor parallelism is used, we shard the weights along
+ # the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
@@ -264,11 +282,13 @@ class GPT2LMHeadModel(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
- loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
+ loaded_weight = loaded_weight.view(3, total_num_heads,
+ head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
- loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
+ loaded_weight = loaded_weight.view(3, total_num_heads,
+ head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
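
Note: the c_attn comments above describe sharding the fused QKV weight along the head dimension. A minimal standalone sketch of that slicing, with made-up dimensions and rank and plain PyTorch rather than the vLLM helpers:

import torch

# Hypothetical dimensions, for illustration only.
total_num_heads, head_size, hidden_size = 12, 64, 768
tp_world_size, tp_rank = 4, 1  # take the shard for rank 1 of 4

qkv_weight = torch.randn(3 * total_num_heads * head_size, hidden_size)

num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads

# View as [3, heads, head_size, hidden], slice the head dimension,
# then flatten back to the 2D layout the linear layer expects.
shard = qkv_weight.view(3, total_num_heads, head_size, hidden_size)
shard = shard[:, head_start:head_end, :, :].reshape(-1, hidden_size)
assert shard.shape == (3 * num_heads * head_size, hidden_size)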

View File

@@ -0,0 +1,340 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
self.tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % self.tensor_model_parallel_world_size == 0
self.num_heads = (total_num_heads //
self.tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.multi_query = config.multi_query
if self.multi_query:
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(self.hidden_size,
self.hidden_size +
2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if self.multi_query:
q, _ = self.c_attn_q(hidden_states)
kv = self.c_attn_kv(hidden_states)
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
else:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class GPTBigMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPTBigCodeConfig,
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTBigCodeForCausalLM(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.config = config
self.transformer = GPTBigCodeModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
total_num_kv_heads = (1 if self.config.multi_query else
total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not self.config.multi_query:
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
else:
# For multi-query attention, we split the query
# but replicate the key and value.
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("c_attn", "c_attn_q")
kv_weight_name = name.replace("c_attn", "c_attn_kv")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
param = state_dict[name]
if name == "transformer.wte.weight":
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
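
Two details in this file are easy to gloss over. The vocab padding simply rounds the embedding table up to the next multiple of 64, e.g. ((50257 + 63) // 64) * 64 == 50304. And for multi-query attention, load_weights shards only the query rows of the fused c_attn weight while replicating the single key/value head on every rank. A rough standalone sketch of that split, with made-up sizes rather than a real checkpoint:

import torch

# Hypothetical sizes for illustration; a real GPTBigCode config is larger.
hidden_size, total_num_heads, tp_size, tp_rank = 768, 12, 4, 0
head_size = hidden_size // total_num_heads
kv_size = head_size  # a single shared KV head in multi-query attention

# Fused HF weight layout: [q rows | k rows | v rows] stacked on dim 0.
c_attn = torch.randn(hidden_size + 2 * kv_size, hidden_size)
wq, wk, wv = torch.split(c_attn, [hidden_size, kv_size, kv_size], dim=0)

num_heads = total_num_heads // tp_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

# Query rows are sharded by head; key/value rows are replicated on every rank.
wq_shard = wq[head_size * head_start:head_size * head_end]
wkv_replicated = torch.cat([wk, wv], dim=0)
assert wq_shard.shape == (num_heads * head_size, hidden_size)
assert wkv_replicated.shape == (2 * kv_size, hidden_size)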

View File

@@ -0,0 +1,254 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTJAttention(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_size,
scaling,
config.rotary_dim,
is_neox_style=False)
self.warmup = False
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
attn_output, _ = self.out_proj(attn_output)
return attn_output
class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(hidden_size,
intermediate_size,
gather_output=False,
perform_initialization=False)
self.fc_out = RowParallelLinear(intermediate_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc_out(hidden_states)
return hidden_states
class GPTJBlock(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
if config.n_inner is None:
inner_dim = 4 * config.n_embd
else:
inner_dim = config.n_inner
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config)
self.mlp = GPTJMLP(inner_dim, config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual
return hidden_states
class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(config.vocab_size,
self.embed_dim,
perform_initialization=False)
self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.config = config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config)
self.lm_head = ColumnParallelLinear(config.n_embd,
config.vocab_size,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata, self.lm_head.bias)
return next_tokens
_column_parallel_weights = [
"wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
"lm_head.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "attn.bias" in name or "attn.masked_bias" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[1]
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
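
The loop above folds the separate q_proj/k_proj/v_proj checkpoint tensors into the fused qkv_proj parameter by copying each one into its own stride of the output rows. A minimal sketch of that copy pattern, using hypothetical sizes and plain tensors instead of the real state_dict:

import torch

hidden_size, tp_size, tp_rank = 768, 2, 0
shard_size = hidden_size // tp_size  # rows of one projection on this rank

# Fused parameter on this rank: rows are [q | k | v], each shard_size tall.
qkv_proj = torch.empty(3 * shard_size, hidden_size)

checkpoint = {
    "q_proj.weight": torch.randn(hidden_size, hidden_size),
    "k_proj.weight": torch.randn(hidden_size, hidden_size),
    "v_proj.weight": torch.randn(hidden_size, hidden_size),
}

for stride_id, name in enumerate(["q_proj", "k_proj", "v_proj"]):
    loaded = checkpoint[f"{name}.weight"]
    # Take this rank's rows from the full projection ...
    shard = loaded[shard_size * tp_rank:shard_size * (tp_rank + 1)]
    # ... and copy them into the matching stride of the fused parameter.
    qkv_proj[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(shard)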

View File

@@ -1,5 +1,6 @@
# coding=utf-8
- # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
@@ -19,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
- from typing import Dict, List, Optional, Tuple
+ from typing import List, Optional, Tuple
import torch
from torch import nn
@@ -35,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
- from vllm.sequence import SequenceOutputs
+ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
- tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+ tensor_model_parallel_world_size = (
+ get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
- self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+ self.num_heads = (self.total_num_heads //
+ tensor_model_parallel_world_size)
- self.query_key_value = ColumnParallelLinear(config.hidden_size,
- 3 * config.hidden_size,
- gather_output=False,
- perform_initialization=False)
- self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
+ self.query_key_value = ColumnParallelLinear(
+ config.hidden_size,
+ 3 * config.hidden_size,
+ gather_output=False,
+ perform_initialization=False)
+ self.dense = RowParallelLinear(config.hidden_size,
+ config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
- scaling = self.head_size ** -0.5
+ scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
@@ -75,11 +80,10 @@ class GPTNeoXAttention(nn.Module):
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
- attn_output = self.attn(
- position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
+ attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
+ input_metadata, cache_event)
output, _ = self.dense(attn_output)
return output
@@ -92,7 +96,8 @@ class GPTNeoXMLP(nn.Module):
config.intermediate_size,
gather_output=False,
perform_initialization=False)
- self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
+ self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
+ config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.hidden_act)
@@ -109,8 +114,10 @@ class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
- self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.input_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config)
@@ -154,10 +161,13 @@ class GPTNeoXModel(nn.Module):
super().__init__()
self.config = config
- self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
+ self.embed_in = VocabParallelEmbedding(config.vocab_size,
+ config.hidden_size,
perform_initialization=False)
- self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
- self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layers = nn.ModuleList(
+ [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps)
def forward(
self,
@@ -191,8 +201,10 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
- self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
- bias=False, gather_output=False,
+ self.embed_out = ColumnParallelLinear(config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
@@ -203,25 +215,29 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
- ) -> Dict[int, SequenceOutputs]:
+ ) -> SamplerOutput:
- hidden_states = self.gpt_neox(
- input_ids, positions, kv_caches, input_metadata, cache_events)
- next_tokens = self.sampler(
- self.embed_out.weight, hidden_states, input_metadata)
+ hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
+ input_metadata, cache_events)
+ next_tokens = self.sampler(self.embed_out.weight, hidden_states,
+ input_metadata)
return next_tokens
- _column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
+ _column_parallel_weights = [
+ "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
+ "dense_h_to_4h.bias"
+ ]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
- def load_weights(self, model_name_or_path: str,
+ def load_weights(self,
+ model_name_or_path: str,
cache_dir: Optional[str] = None,
- use_np_cache: bool = False):
+ load_format: str = "auto"):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, use_np_cache):
+ model_name_or_path, cache_dir, load_format):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
param = state_dict[name]
if "query_key_value" in name:
@@ -230,17 +246,19 @@ class GPTNeoXForCausalLM(nn.Module):
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
- loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
- :shard_size * (tensor_model_parallel_rank + 1)]
+ loaded_weight = loaded_weight[
+ shard_size * tensor_model_parallel_rank:shard_size *
+ (tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
- if 'query_key_value.weight' in name:
+ if "query_key_value.weight" in name:
- loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
+ loaded_weight = loaded_weight.view(-1, 3, head_size,
+ hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
- elif 'query_key_value.bias' in name:
+ elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
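
GPT-NeoX checkpoints store query_key_value with q, k and v interleaved per head, while the paged attention kernel expects the three projections concatenated as [3 * num_heads * head_size, hidden_size]; the view/transpose/reshape above performs that conversion. A small standalone illustration with hypothetical sizes:

import torch

num_heads, head_size, hidden_size = 8, 64, 512

# Checkpoint layout: for each head, its q, k and v rows are interleaved.
ckpt_qkv = torch.randn(num_heads * 3 * head_size, hidden_size)

# Convert to the concatenated layout [all q rows | all k rows | all v rows].
w = ckpt_qkv.view(num_heads, 3, head_size, hidden_size)
w = w.transpose(0, 1)                      # [3, num_heads, head_size, hidden]
w = w.reshape(3 * num_heads * head_size, hidden_size)

# Sanity check: head 0's query rows end up at the top of the converted tensor.
assert torch.equal(w[:head_size], ckpt_qkv[:head_size])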

View File

@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class InternLMMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class InternLMAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
3 * self.total_num_heads * self.head_dim,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class InternLMDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = InternLMAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
)
self.mlp = InternLMMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class InternLMModel(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class InternLMForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = InternLMModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
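
The gate_proj/up_proj branch above uses the same stride-copy idea as the attention branch: each checkpoint tensor fills half of the fused gate_up_proj rows, and each tensor-parallel rank keeps only its slice. A rough sketch with made-up sizes:

import torch

hidden_size, intermediate_size, tp_size, tp_rank = 512, 2048, 2, 1
shard_size = intermediate_size // tp_size  # rows per projection per rank

# Fused parameter on this rank: [gate rows | up rows].
gate_up_proj = torch.empty(2 * shard_size, hidden_size)

checkpoint = {
    "gate_proj.weight": torch.randn(intermediate_size, hidden_size),
    "up_proj.weight": torch.randn(intermediate_size, hidden_size),
}

for stride_id, name in enumerate(["gate_proj", "up_proj"]):
    loaded = checkpoint[f"{name}.weight"]
    # This rank's slice of the full projection ...
    shard = loaded[shard_size * tp_rank:shard_size * (tp_rank + 1)]
    # ... copied into the matching half of the fused parameter.
    gate_up_proj[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(shard)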

View File

@@ -1,5 +1,6 @@
# coding=utf-8
- # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
@@ -24,25 +25,25 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
- from typing import Dict, List, Optional, Tuple
+ from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
- from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
- from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
- load_tensor_parallel_weights)
+ from vllm.model_executor.weight_utils import (
+ load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
+ hf_model_weights_iterator)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
- from vllm.sequence import SequenceOutputs
+ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -56,15 +57,19 @@ class LlamaMLP(nn.Module):
hidden_act: str,
):
super().__init__()
- self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
- bias=False, gather_output=False,
+ self.gate_up_proj = ColumnParallelLinear(hidden_size,
+ 2 * intermediate_size,
+ bias=False,
+ gather_output=False,
perform_initialization=False)
- self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
- bias=False, input_is_parallel=True,
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ input_is_parallel=True,
perform_initialization=False)
- if hidden_act != 'silu':
- raise ValueError(f'Unsupported activation: {hidden_act}. '
- 'Only silu is supported for now.')
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
@@ -80,19 +85,28 @@ class LlamaAttention(nn.Module):
self,
hidden_size: int,
num_heads: int,
+ num_kv_heads: int,
+ rope_theta: float = 10000,
):
super().__init__()
self.hidden_size = hidden_size
- tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+ tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
- assert self.total_num_heads % tensor_model_parallel_world_size == 0
- self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ assert self.total_num_kv_heads % tp_size == 0
+ self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads
- self.scaling = self.head_dim ** -0.5
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
self.qkv_proj = ColumnParallelLinear(
hidden_size,
- 3 * self.total_num_heads * self.head_dim,
+ (self.total_num_heads + 2 * self.total_num_kv_heads) *
+ self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
@@ -104,8 +118,12 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
- self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
- self.scaling, rotary_dim=self.head_dim)
+ self.attn = PagedAttentionWithRoPE(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ base=self.rope_theta,
+ rotary_dim=self.head_dim,
+ num_kv_heads=self.num_kv_heads)
def forward(
self,
@@ -116,10 +134,10 @@ class LlamaAttention(nn.Module):
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = qkv.chunk(chunks=3, dim=-1)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
- attn_output = self.attn(
- positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
+ attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+ input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
@@ -129,17 +147,23 @@ class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
+ # Requires transformers > 4.32.0
+ rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ rope_theta=rope_theta,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
def forward(
self,
@@ -177,9 +201,12 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
- perform_initialization=False)
- self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ vocab_size = ((config.vocab_size + 63) // 64) * 64
+ self.embed_tokens = VocabParallelEmbedding(
+ vocab_size, config.hidden_size, perform_initialization=False)
+ self.layers = nn.ModuleList([
+ LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
+ ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
@@ -209,12 +236,14 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = LlamaModel(config)
+ vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
- config.vocab_size,
+ vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
@@ -227,41 +256,54 @@ class LlamaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
- ) -> Dict[int, SequenceOutputs]:
+ ) -> SamplerOutput:
- hidden_states = self.model(
- input_ids, positions, kv_caches, input_metadata, cache_events)
- next_tokens = self.sampler(
- self.lm_head.weight, hidden_states, input_metadata)
+ hidden_states = self.model(input_ids, positions, kv_caches,
+ input_metadata, cache_events)
+ next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+ input_metadata)
return next_tokens
- _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
- "qkv_proj.weight", "gate_proj.weight",
- "up_proj.weight"]
+ _column_parallel_weights = [
+ "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
+ ]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
- def load_weights(self, model_name_or_path: str,
+ def load_weights(self,
+ model_name_or_path: str,
cache_dir: Optional[str] = None,
- use_np_cache: bool = False):
+ load_format: str = "auto"):
+ tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+ q_proj_shard_size = (self.config.hidden_size // tp_size)
+ kv_proj_shard_size = (self.config.hidden_size //
+ self.config.num_attention_heads *
+ self.config.num_key_value_heads // tp_size)
+ attention_weight_specs = [
+ # (weight_name, shard_size, offset)
+ ("q_proj", q_proj_shard_size, 0),
+ ("k_proj", kv_proj_shard_size, q_proj_shard_size),
+ ("v_proj", kv_proj_shard_size,
+ q_proj_shard_size + kv_proj_shard_size),
+ ]
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, use_np_cache):
+ model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
- for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
- if att_weight_name not in name:
+ for weight_name, shard_size, offset in attention_weight_specs:
+ if weight_name not in name:
continue
- param = state_dict[name.replace(att_weight_name, "qkv_proj")]
- shard_size = param.shape[0] // 3
+ param = state_dict[name.replace(weight_name, "qkv_proj")]
loaded_weight = loaded_weight[
- shard_size * tensor_model_parallel_rank
- :shard_size * (tensor_model_parallel_rank + 1)]
+ shard_size * tensor_model_parallel_rank:shard_size *
+ (tensor_model_parallel_rank + 1)]
- param_slice = param.data[shard_size * stride_id
- :shard_size * (stride_id + 1)]
+ param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
@@ -275,10 +317,10 @@ class LlamaForCausalLM(nn.Module):
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
- shard_size * tensor_model_parallel_rank
- :shard_size * (tensor_model_parallel_rank + 1)]
+ shard_size * tensor_model_parallel_rank:shard_size *
+ (tensor_model_parallel_rank + 1)]
- param_slice = param.data[shard_size * stride_id
- :shard_size * (stride_id + 1)]
+ param_slice = param.data[shard_size * stride_id:shard_size *
+ (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
@@ -287,6 +329,12 @@ class LlamaForCausalLM(nn.Module):
continue
param = state_dict[name]
+ if "embed_tokens" in name or "lm_head" in name:
+ load_padded_tensor_parallel_vocab(param, loaded_weight,
+ tensor_model_parallel_rank)
+ continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
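
With grouped-query attention the query and key/value projections contribute different numbers of rows to the fused qkv_proj parameter, so the new loader computes separate shard sizes and row offsets. A worked example with hypothetical sizes (8192 hidden, 64 query heads, 8 KV heads, tensor parallel size 4), not taken from any particular checkpoint:

# Hypothetical config values, for illustration only.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8
tp_size = 4

head_dim = hidden_size // num_attention_heads                    # 128
q_proj_shard_size = hidden_size // tp_size                       # 2048 Q rows per rank
kv_proj_shard_size = head_dim * num_key_value_heads // tp_size   # 256 K (and V) rows per rank

# Row offsets of each projection inside the fused qkv_proj parameter.
attention_weight_specs = [
    ("q_proj", q_proj_shard_size, 0),
    ("k_proj", kv_proj_shard_size, q_proj_shard_size),
    ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size),
]
assert q_proj_shard_size + 2 * kv_proj_shard_size == 2560  # fused rows per rank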

View File

@@ -0,0 +1,281 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(
total_num_heads: int,
alibi_bias_max: int,
) -> torch.Tensor:
next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / next_power_of_2)
slopes = 1.0 / torch.pow(2, m)
if next_power_of_2 != total_num_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
return slopes
class MPTAttention(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
self.qkv_proj = ColumnParallelLinear(
self.d_model,
3 * self.d_model,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
self.k_ln = nn.LayerNorm(self.d_model)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads,
self.alibi_bias_max)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # unused.
qkv, _ = self.qkv_proj(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.out_proj(attn_output)
return output
class MPTMLP(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
x = self.act(x)
x, _ = self.down_proj(x)
return x
class MPTBlock(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
x = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=x,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = hidden_states + x
x = self.norm_2(hidden_states)
x = self.ffn(x)
hidden_states = hidden_states + x
return hidden_states
class MPTModel(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
self.wte = VocabParallelEmbedding(config.vocab_size,
config.d_model,
perform_initialization=False)
self.blocks = nn.ModuleList(
[MPTBlock(config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
if hasattr(module, "bias"):
if isinstance(module.bias, nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm_f(hidden_states)
return hidden_states
class MPTForCausalLM(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
self.config = config
assert config.tie_word_embeddings
self.transformer = MPTModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
_row_parallel_weights = ["out_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "Wqkv" in name:
# NOTE(woosuk): MPT's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor model parallelism is used, we need to shard
# the weight along the hidden dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
name = name.replace("Wqkv", "qkv_proj")
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
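Aside, not part of the model code above: a minimal standalone sketch of the head-wise sharding that the `Wqkv` branch performs, using toy shapes.

import torch

# Toy shapes: hidden_size = total_num_heads * head_size.
total_num_heads, head_size, tp_world_size = 8, 4, 2
hidden_size = total_num_heads * head_size
fused_qkv = torch.randn(3 * total_num_heads * head_size, hidden_size)

num_heads = total_num_heads // tp_world_size
for tp_rank in range(tp_world_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    # View as [3, heads, head_size, hidden], keep this rank's heads, flatten back.
    shard = fused_qkv.view(3, total_num_heads, head_size, hidden_size)
    shard = shard[:, head_start:head_end, :, :].reshape(-1, hidden_size)
    assert shard.shape == (3 * num_heads * head_size, hidden_size)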

View File

@@ -1,7 +1,9 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -19,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -35,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int): def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # OPT is set up so that if padding_idx is specified then offset the
# and adjust num_embeddings appropriately. Other models don't have this hack # embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2 self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim) super().__init__(num_embeddings + self.offset, embedding_dim)
@@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0 assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias, self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling) scale=self.scaling)
def forward( def forward(
@@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(q, k, v, key_cache, value_cache,
q, k, v, key_cache, value_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm( self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.embed_dim,
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias, bias=config.enable_bias,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim, self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias, bias=config.enable_bias,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
def forward( def forward(
self, self,
@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before: if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn( hidden_states = self.self_attn(hidden_states=hidden_states,
hidden_states=hidden_states, kv_cache=kv_cache,
kv_cache=kv_cache, input_metadata=input_metadata,
input_metadata=input_metadata, cache_event=cache_event)
cache_event=cache_event)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention # 350m applies layer norm AFTER attention
if not self.do_layer_norm_before: if not self.do_layer_norm_before:
@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
self.max_target_positions = config.max_position_embeddings self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(
config.word_embed_proj_dim, config.vocab_size,
perform_initialization=False) config.word_embed_proj_dim,
perform_initialization=False)
# Positional embeddings are replicated (not sharded). # Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding( self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size) config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist. # Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) self.project_out = nn.Linear(config.hidden_size,
config.word_embed_proj_dim,
bias=False)
else: else:
self.project_out = None self.project_out = None
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) self.project_in = nn.Linear(config.word_embed_proj_dim,
config.hidden_size,
bias=False)
else: else:
self.project_in = None self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility # Note that the only purpose of `config._remove_final_layer_norm` is to
# with checkpoints that have been fine-tuned before transformers v4.20.1 # keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164 # see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm: if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine config.hidden_size,
) elementwise_affine=config.layer_norm_elementwise_affine)
else: else:
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward( def forward(
self, self,
@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
hidden_states, kv_caches[i], input_metadata, cache_event) cache_event)
if self.final_layer_norm is not None: if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
@@ -241,8 +260,8 @@ class OPTModel(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor: ) -> torch.Tensor:
return self.decoder( return self.decoder(input_ids, positions, kv_caches, input_metadata,
input_ids, positions, kv_caches, input_metadata, cache_events) cache_events)
class OPTForCausalLM(nn.Module): class OPTForCausalLM(nn.Module):
@@ -263,24 +282,27 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model( hidden_states = self.model(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head_weight, hidden_states,
self.lm_head_weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"] _column_parallel_weights = [
"embed_tokens.weight", "fc1.weight", "fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"] _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): load_format: str = "auto"):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, load_format):
if "lm_head.weight" in name: if "lm_head.weight" in name:
continue continue
@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
name = "model." + name name = "model." + name
is_attention_weight = False is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name: if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3 shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank shard_size * tensor_model_parallel_rank:shard_size *
:shard_size * (tensor_model_parallel_rank + 1)] (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id param_slice = param.data[shard_size * stride_id:shard_size *
:shard_size * (stride_id + 1)] (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight) param_slice.copy_(loaded_weight)
is_attention_weight = True is_attention_weight = True

View File

@@ -0,0 +1,316 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class QWenMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str = "silu",
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.c_proj(x)
return x
class QWenAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int,
max_position_embeddings: int):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
# pylint: disable=invalid-name
self.c_attn = ColumnParallelLinear(
hidden_size,
3 * hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.c_proj(attn_output)
return output
class QWenBlock(nn.Module):
def __init__(self, config: QWenConfig):
super().__init__()
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
config.max_position_embeddings)
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class QWenModel(nn.Module):
def __init__(self, config: QWenConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size,
config.n_embd,
perform_initialization=False)
self.h = nn.ModuleList(
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class QWenLMHeadModel(nn.Module):
def __init__(self, config: QWenConfig):
super().__init__()
self.config = config
self.transformer = QWenModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(
config.n_embd,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = []
_row_parallel_weights = ["c_proj.weight"]
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "c_attn" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if "weight" in name:
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif "bias" in name:
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["w2", "w1"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if "wte" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
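Aside, not part of the diff: a toy sketch of how the separate `w2`/`w1` MLP weights are packed into the fused `gate_up_proj` parameter, one stride per weight (in the order of the loop above) and one shard per tensor-parallel rank.

import torch

intermediate_size, hidden_size, tp_world_size, tp_rank = 16, 8, 2, 0
shard_size = intermediate_size // tp_world_size
# gate_up_proj stores the w2 shard followed by the w1 shard along dim 0.
gate_up_proj = torch.empty(2 * shard_size, hidden_size)

w2 = torch.randn(intermediate_size, hidden_size)  # stride_id 0
w1 = torch.randn(intermediate_size, hidden_size)  # stride_id 1
for stride_id, full_weight in enumerate([w2, w1]):
    shard = full_weight[shard_size * tp_rank:shard_size * (tp_rank + 1)]
    gate_up_proj[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(shard)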

View File

@@ -1,9 +1,6 @@
import vllm.model_executor.parallel_utils.parallel_state import vllm.model_executor.parallel_utils.parallel_state
import vllm.model_executor.parallel_utils.tensor_parallel import vllm.model_executor.parallel_utils.tensor_parallel
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [ __all__ = [
"parallel_state", "parallel_state",
"tensor_parallel", "tensor_parallel",

View File

@@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks # rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None _DATA_PARALLEL_GLOBAL_RANKS = None
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
def initialize_model_parallel( def initialize_model_parallel(
tensor_model_parallel_size: int = 1, tensor_model_parallel_size: int = 1,
@@ -196,20 +195,6 @@ def initialize_model_parallel(
if rank in ranks: if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
def initialize_all_reduce_launcher(
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
disable_graph: bool = False,
) -> None:
global _ALL_REDUCE_LAUNCHER
_ALL_REDUCE_LAUNCHER = GraphAllReduce(
max_num_tokens=max_num_tokens,
hidden_size=hidden_size,
dtype=dtype,
disable_graph=disable_graph,
)
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized.""" """Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \ if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
last_rank_local = get_pipeline_model_parallel_world_size() - 1 last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local] return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank(): def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline""" """Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
@@ -485,10 +471,6 @@ def get_data_parallel_rank():
"""Return my rank for the data parallel group.""" """Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group()) return torch.distributed.get_rank(group=get_data_parallel_group())
def get_all_reduce_launcher() -> 'GraphAllReduce':
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
return _ALL_REDUCE_LAUNCHER
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none.""" """Set the groups to none."""
global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GROUP
@@ -515,56 +497,3 @@ def destroy_model_parallel():
_MPU_TENSOR_MODEL_PARALLEL_RANK = None _MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
class GraphAllReduce:
def __init__(
self,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
disable_graph: bool = False,
) -> None:
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.disable_graph = disable_graph
tp_world_size = get_tensor_model_parallel_world_size()
if tp_world_size == 1:
return
self.group = get_tensor_model_parallel_group()
self.buffer = torch.empty(
size=(max_num_tokens, hidden_size),
dtype=dtype,
device='cuda',
)
# Build graphs for different number of tokens.
if not self.disable_graph:
self.graphs = {}
for num_tokens in range(8, max_num_tokens + 1, 8):
self.graphs[num_tokens] = self._build_graph(num_tokens)
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
# Warm up.
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
torch.cuda.synchronize()
# Build graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
torch.distributed.all_reduce(
self.buffer[:num_tokens], group=self.group)
torch.cuda.synchronize()
return graph
def launch(self, x: torch.Tensor) -> torch.Tensor:
# NOTE: x must be a slice of self.buffer.
num_tokens = x.shape[0]
if self.disable_graph:
torch.distributed.all_reduce(x, group=self.group)
else:
self.graphs[num_tokens].replay()
return x

View File

@@ -12,6 +12,7 @@ from .mappings import (
copy_to_tensor_model_parallel_region, copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region, gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region, gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region, scatter_to_sequence_parallel_region,
) )
@@ -38,7 +39,7 @@ __all__ = [
"copy_to_tensor_model_parallel_region", "copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region", "gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region", "gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region", "reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region", "scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region", "scatter_to_sequence_parallel_region",
# random.py # random.py

View File

@@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_all_reduce_launcher,
) )
from .mappings import ( from .mappings import (
copy_to_tensor_model_parallel_region, copy_to_tensor_model_parallel_region,
@@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size self.output_size = output_size
self.gather_output = gather_output self.gather_output = gather_output
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size) self.output_size_per_partition = divide(output_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
if params_dtype is None: if params_dtype is None:
@@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module):
params_dtype: params_dtype:
use_cpu_initialization: use_cpu_initialization:
perform_initialization: perform_initialization:
reduce_results:
""" """
def __init__(self, input_size, output_size, *, def __init__(self, input_size, output_size, *,
@@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module):
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=True,
reduce_results=True,
): ):
super(RowParallelLinear, self).__init__() super(RowParallelLinear, self).__init__()
@@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module):
self.input_size = input_size self.input_size = input_size
self.output_size = output_size self.output_size = output_size
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size) self.input_size_per_partition = divide(input_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
# Parameters. # Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose. # we allocate the transpose.
@@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = input_ input_parallel = input_
else: else:
input_parallel = scatter_to_tensor_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
if get_tensor_model_parallel_world_size() == 1: # Matrix multiply.
# Matrix multiply. output_parallel = F.linear(input_parallel, self.weight)
output_ = F.linear(input_parallel, self.weight) if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
else: else:
# Matrix multiply. output_ = output_parallel
all_reduce_launcher = get_all_reduce_launcher()
num_tokens = input_parallel.shape[0]
output_buffer = all_reduce_launcher.buffer[:num_tokens]
torch.matmul(input_parallel, self.weight_t, out=output_buffer)
# All-reduce across all the partitions.
output_ = all_reduce_launcher.launch(output_buffer)
if not self.skip_bias_add: if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_ output = output_ + self.bias if self.bias is not None else output_
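Aside, not part of the diff: a single-process sketch of the simplified forward path above. Each rank computes a partial matmul on its shard of the input dimension; summing the partials stands in for the all-reduce that runs when `reduce_results` is True and the world size is greater than 1.

import torch
import torch.nn.functional as F

world_size, tokens, in_per_rank, out_features = 2, 3, 8, 4
inputs = [torch.randn(tokens, in_per_rank) for _ in range(world_size)]         # input_parallel per rank
weights = [torch.randn(out_features, in_per_rank) for _ in range(world_size)]  # weight shard per rank

partials = [F.linear(x, w) for x, w in zip(inputs, weights)]  # output_parallel per rank
output = torch.stack(partials).sum(dim=0)                     # stand-in for the all-reduce
assert output.shape == (tokens, out_features)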

View File

@@ -3,13 +3,19 @@ import filelock
import glob import glob
import json import json
import os import os
from typing import Iterator, List, Optional, Tuple from collections import defaultdict
from typing import Iterator, List, Optional, Tuple, Any
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open
import numpy as np import numpy as np
import torch import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.logger import init_logger
logger = init_logger(__name__)
class Disabledtqdm(tqdm): class Disabledtqdm(tqdm):
@@ -17,50 +23,150 @@ class Disabledtqdm(tqdm):
super().__init__(*args, **kwargs, disable=True) super().__init__(*args, **kwargs, disable=True)
def hf_model_weights_iterator( def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
# Prepare file lock directory to prevent multiple processes from
# downloading the same model weights at the same time.
lock_dir = cache_dir if cache_dir is not None else "/tmp" lock_dir = cache_dir if cache_dir is not None else "/tmp"
lock_file_name = model_name_or_path.replace("/", "-") + ".lock" lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name)) lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
return lock
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for _, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
def convert_bin_to_safetensor_file(
pt_filename: str,
sf_filename: str,
):
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
shared = _shared_pointers(loaded)
for shared_weights in shared:
for name in shared_weights[1:]:
loaded.pop(name)
# For tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}
dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
save_file(loaded, sf_filename, metadata={"format": "pt"})
# check file size
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
# check if the tensors are the same
reloaded = load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
if not torch.equal(pt_tensor, sf_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def prepare_hf_model_weights(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_safetensors: bool = False,
fall_back_to_pt: bool = True,
):
# Download model weights from huggingface. # Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
if not is_local: if not is_local:
with lock: # Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path, hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.bin", allow_patterns=allow_patterns,
cache_dir=cache_dir, cache_dir=cache_dir,
tqdm_class=Disabledtqdm) tqdm_class=Disabledtqdm)
else: else:
hf_folder = model_name_or_path hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
if not use_safetensors:
hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin")
]
hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin")) if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
return prepare_hf_model_weights(model_name_or_path,
cache_dir=cache_dir,
use_safetensors=False,
fall_back_to_pt=False)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
) -> Iterator[Tuple[str, torch.Tensor]]:
use_safetensors = False
use_np_cache = False
fall_back_to_pt = False
if load_format == "auto":
use_safetensors = True
fall_back_to_pt = True
elif load_format == "safetensors":
use_safetensors = True
elif load_format == "pt":
pass
elif load_format == "npcache":
use_np_cache = True
else:
raise ValueError(f"Unknown load_format: {load_format}")
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
model_name_or_path,
cache_dir=cache_dir,
use_safetensors=use_safetensors,
fall_back_to_pt=fall_back_to_pt)
if use_np_cache: if use_np_cache:
# Currently np_cache only supports *.bin checkpoints
assert use_safetensors is False
# Convert the model weights from torch tensors to numpy arrays for # Convert the model weights from torch tensors to numpy arrays for
# faster loading. # faster loading.
np_folder = os.path.join(hf_folder, 'np') np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True) os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json') weight_names_file = os.path.join(np_folder, "weight_names.json")
with lock: # Use file lock to prevent multiple processes from
# dumping the same model weights to numpy at the same time.
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file): if not os.path.exists(weight_names_file):
weight_names = [] weight_names = []
for bin_file in hf_bin_files: for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu")
for name, param in state.items(): for name, param in state.items():
param_path = os.path.join(np_folder, name) param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f: with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy()) np.save(f, param.cpu().detach().numpy())
weight_names.append(name) weight_names.append(name)
with open(weight_names_file, 'w') as f: with open(weight_names_file, "w") as f:
json.dump(weight_names, f) json.dump(weight_names, f)
with open(weight_names_file, 'r') as f: with open(weight_names_file, "r") as f:
weight_names = json.load(f) weight_names = json.load(f)
for name in weight_names: for name in weight_names:
@@ -68,16 +174,52 @@ def hf_model_weights_iterator(
with open(param_path, "rb") as f: with open(param_path, "rb") as f:
param = np.load(f) param = np.load(f)
yield name, torch.from_numpy(param) yield name, torch.from_numpy(param)
elif use_safetensors:
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys():
param = f.get_slice(name)
yield name, param
else: else:
for bin_file in hf_bin_files: for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu")
for name, param in state.items(): for name, param in state.items():
yield name, param yield name, param
del state
torch.cuda.empty_cache()
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor
PySafeSlice object supports indexing, which is done before loading the
actual tensor and can reduce the amount of memory being read into the
memory. However, it does not support more advanced functionalities
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
tensor with these more complicated operators, we need to convert to
tensor first.
"""
if not isinstance(x, torch.Tensor):
x = x[:]
return x
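Aside, not part of the diff: a sketch of the lazy-slicing behaviour described in the docstring above, with a made-up file name and tensor name. Indexing a PySafeSlice reads only the requested rows; `[:]` materializes the whole tensor, which is what `convert_pyslice_to_tensor` falls back to.

from safetensors.torch import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    lazy = f.get_slice("transformer.wte.weight")  # PySafeSlice, nothing loaded yet
    shard = lazy[0:1024]                          # reads only these rows from disk
    full = lazy[:]                                # same as convert_pyslice_to_tensor(lazy)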
def load_padded_tensor_parallel_vocab(
param: torch.Tensor,
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
tensor_model_parallel_rank: int,
) -> None:
shard_size = param.shape[0]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
param[:loaded_weight.shape[0]].copy_(loaded_weight)
def load_tensor_parallel_weights( def load_tensor_parallel_weights(
param: torch.Tensor, param: torch.Tensor,
loaded_weight: torch.Tensor, loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
param_name: str, param_name: str,
column_parallel_weight_names: List[str], column_parallel_weight_names: List[str],
row_parallel_weight_names: List[str], row_parallel_weight_names: List[str],
@@ -86,19 +228,22 @@ def load_tensor_parallel_weights(
for p in column_parallel_weight_names: for p in column_parallel_weight_names:
if p in param_name: if p in param_name:
shard_size = param.shape[0] shard_size = param.shape[0]
loaded_weight = loaded_weight[ start_idx = tensor_model_parallel_rank * shard_size
shard_size * tensor_model_parallel_rank end_idx = (tensor_model_parallel_rank + 1) * shard_size
:shard_size * (tensor_model_parallel_rank + 1)] loaded_weight = loaded_weight[start_idx:end_idx]
break break
for p in row_parallel_weight_names: for p in row_parallel_weight_names:
if p in param_name: if p in param_name:
shard_size = param.shape[1] shard_size = param.shape[1]
loaded_weight = loaded_weight[ start_idx = tensor_model_parallel_rank * shard_size
:, end_idx = (tensor_model_parallel_rank + 1) * shard_size
shard_size * tensor_model_parallel_rank loaded_weight = loaded_weight[:, start_idx:end_idx]
:shard_size * (tensor_model_parallel_rank + 1)]
break break
assert param.shape == loaded_weight.shape
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: "
f"{param.shape} != {loaded_weight.shape}")
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)
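A short usage sketch tying the pieces above together; the model name is only an example. With `load_format="auto"` the iterator prefers `*.safetensors` files and falls back to `*.bin`, and `convert_pyslice_to_tensor` materializes any lazy slices it yields.

from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                              hf_model_weights_iterator)

for name, loaded_weight in hf_model_weights_iterator("facebook/opt-125m",
                                                     load_format="auto"):
    tensor = convert_pyslice_to_tensor(loaded_weight)  # no-op for plain tensors
    print(name, tuple(tensor.shape))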

View File

@@ -53,27 +53,34 @@ class RequestOutput:
prompt: The prompt string of the request. prompt: The prompt string of the request.
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
outputs: The output sequences of the request. outputs: The output sequences of the request.
finished: Whether the whole request is finished.
""" """
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
prompt: str, prompt: str,
prompt_token_ids: List[int], prompt_token_ids: List[int],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
finished: bool,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.outputs = outputs self.outputs = outputs
self.finished = finished
@classmethod @classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences. # Get the top-n sequences.
n = seq_group.sampling_params.n n = seq_group.sampling_params.n
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
assert n <= len(seqs) if seq_group.sampling_params.use_beam_search:
sorted_seqs = sorted( sorting_key = lambda seq: seq.get_beam_search_score(
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True) seq_group.sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n] top_n_seqs = sorted_seqs[:n]
# Create the outputs. # Create the outputs.
@@ -95,13 +102,13 @@ class RequestOutput:
# Every sequence in the sequence group should have the same prompt. # Every sequence in the sequence group should have the same prompt.
prompt = top_n_seqs[0].prompt prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs) finished = seq_group.is_finished()
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
finished)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, " return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, " f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs})") f"outputs={self.outputs}, "
f"finished={self.finished})")
def finished(self) -> bool:
return all(output.finished() for output in self.outputs)

View File

@@ -1,6 +1,8 @@
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
from typing import List, Optional, Union from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingParams: class SamplingParams:
"""Sampling parameters for text generation. """Sampling parameters for text generation.
@@ -32,6 +34,15 @@ class SamplingParams:
top_k: Integer that controls the number of top tokens to consider. Set top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens. to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling. use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated. stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings. The returned output will not contain the stop strings.
ignore_eos: Whether to ignore the EOS token and continue generating ignore_eos: Whether to ignore the EOS token and continue generating
@@ -50,7 +61,9 @@ class SamplingParams:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
use_beam_search: bool = False, use_beam_search: bool = False,
stop: Union[str, List[str]] = [], length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False, ignore_eos: bool = False,
max_tokens: int = 16, max_tokens: int = 16,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
@@ -63,17 +76,26 @@ class SamplingParams:
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
self.use_beam_search = use_beam_search self.use_beam_search = use_beam_search
self.stop = [stop] if isinstance(stop, str) else list(stop) self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.stop = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = list(stop)
self.ignore_eos = ignore_eos self.ignore_eos = ignore_eos
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.logprobs = logprobs self.logprobs = logprobs
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
self._verity_beam_search() self._verify_beam_search()
elif self.temperature == 0.0: else:
# Zero temperature means greedy sampling. self._verify_non_beam_search()
self._verify_greedy_sampling() if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.n < 1: if self.n < 1:
@@ -102,22 +124,36 @@ class SamplingParams:
raise ValueError( raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.") f"logprobs must be non-negative, got {self.logprobs}.")
def _verity_beam_search(self) -> None: def _verify_beam_search(self) -> None:
if self.best_of == 1: if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam " raise ValueError("best_of must be greater than 1 when using beam "
f"search. Got {self.best_of}.") f"search. Got {self.best_of}.")
if self.temperature > 0.0: if self.temperature > _SAMPLING_EPS:
raise ValueError("temperature must be 0 when using beam search.") raise ValueError("temperature must be 0 when using beam search.")
if self.top_p < 1.0: if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using beam search.") raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1: if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.") raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None: def _verify_greedy_sampling(self) -> None:
if self.best_of > 1: if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling." raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.") f"Got {self.best_of}.")
if self.top_p < 1.0: if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using greedy sampling.") raise ValueError("top_p must be 1 when using greedy sampling.")
if self.top_k != -1: if self.top_k != -1:
raise ValueError("top_k must be -1 when using greedy sampling.") raise ValueError("top_k must be -1 when using greedy sampling.")
@@ -131,6 +167,8 @@ class SamplingParams:
f"top_p={self.top_p}, " f"top_p={self.top_p}, "
f"top_k={self.top_k}, " f"top_k={self.top_k}, "
f"use_beam_search={self.use_beam_search}, " f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, " f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, " f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, " f"max_tokens={self.max_tokens}, "

View File

@@ -1,3 +1,4 @@
"""Sequence and its related classes."""
import copy import copy
import enum import enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
@@ -7,12 +8,14 @@ from vllm.sampling_params import SamplingParams
class SequenceStatus(enum.Enum): class SequenceStatus(enum.Enum):
"""Status of a sequence."""
WAITING = enum.auto() WAITING = enum.auto()
RUNNING = enum.auto() RUNNING = enum.auto()
SWAPPED = enum.auto() SWAPPED = enum.auto()
FINISHED_STOPPED = enum.auto() FINISHED_STOPPED = enum.auto()
FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto() FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
@staticmethod @staticmethod
def is_finished(status: "SequenceStatus") -> bool: def is_finished(status: "SequenceStatus") -> bool:
@@ -20,6 +23,7 @@ class SequenceStatus(enum.Enum):
SequenceStatus.FINISHED_STOPPED, SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED, SequenceStatus.FINISHED_LENGTH_CAPPED,
SequenceStatus.FINISHED_ABORTED, SequenceStatus.FINISHED_ABORTED,
SequenceStatus.FINISHED_IGNORED,
] ]
@staticmethod @staticmethod
@@ -30,12 +34,25 @@ class SequenceStatus(enum.Enum):
finish_reason = "length" finish_reason = "length"
elif status == SequenceStatus.FINISHED_ABORTED: elif status == SequenceStatus.FINISHED_ABORTED:
finish_reason = "abort" finish_reason = "abort"
elif status == SequenceStatus.FINISHED_IGNORED:
finish_reason = "length"
else: else:
finish_reason = None finish_reason = None
return finish_reason return finish_reason
class SequenceData: class SequenceData:
"""Data associated with a sequence.
Args:
prompt_token_ids: The token IDs of the prompt.
Attributes:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
def __init__( def __init__(
self, self,
@@ -52,6 +69,9 @@ class SequenceData:
def get_len(self) -> int: def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids) return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_prompt_len(self) -> int:
return len(self.prompt_token_ids)
def get_output_len(self) -> int: def get_output_len(self) -> int:
return len(self.output_token_ids) return len(self.output_token_ids)
@@ -71,6 +91,15 @@ class SequenceData:
class Sequence: class Sequence:
"""Stores the data, status, and block information of a sequence.
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
"""
def __init__( def __init__(
self, self,
@@ -101,7 +130,8 @@ class Sequence:
self.logical_token_blocks.append(block) self.logical_token_blocks.append(block)
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
while token_ids: cursor = 0
while cursor < len(token_ids):
if not self.logical_token_blocks: if not self.logical_token_blocks:
self._append_logical_block() self._append_logical_block()
@@ -111,8 +141,9 @@ class Sequence:
last_block = self.logical_token_blocks[-1] last_block = self.logical_token_blocks[-1]
num_empty_slots = last_block.get_num_empty_slots() num_empty_slots = last_block.get_num_empty_slots()
last_block.append_tokens(token_ids[:num_empty_slots]) last_block.append_tokens(token_ids[cursor:cursor +
token_ids = token_ids[num_empty_slots:] num_empty_slots])
cursor += num_empty_slots
def append_token_id( def append_token_id(
self, self,
@@ -127,6 +158,9 @@ class Sequence:
def get_len(self) -> int: def get_len(self) -> int:
return self.data.get_len() return self.data.get_len()
def get_prompt_len(self) -> int:
return self.data.get_prompt_len()
def get_output_len(self) -> int: def get_output_len(self) -> int:
return self.data.get_output_len() return self.data.get_output_len()
@@ -142,22 +176,48 @@ class Sequence:
def get_cumulative_logprob(self) -> float: def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 0.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# Note: HF implementation does not count the EOS token
# towards the length; we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
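Worked aside: the score above divides the (negative) cumulative log-probability by `seq_len ** length_penalty`, so a larger length_penalty brings the score of longer sequences closer to zero and thereby favors them.

cumulative_logprob = -6.0   # sum of per-token log-probs
seq_len = 12
length_penalty = 1.0
score = cumulative_logprob / (seq_len ** length_penalty)
assert score == -0.5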
def is_finished(self) -> bool: def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status) return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: 'Sequence') -> None: def fork(self, new_seq_id: int) -> "Sequence":
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks) new_seq = copy.deepcopy(self)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) new_seq.seq_id = new_seq_id
child_seq.data = copy.deepcopy(self.data) return new_seq
return None
def __repr__(self) -> str: def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, ' return (f"Sequence(seq_id={self.seq_id}, "
f'status={self.status.name}, ' f"status={self.status.name}, "
f'num_blocks={len(self.logical_token_blocks)})') f"num_blocks={len(self.logical_token_blocks)})")
 class SequenceGroup:
+    """A group of sequences that are generated from the same prompt.
+
+    Args:
+        request_id: The ID of the request.
+        seqs: The list of sequences.
+        sampling_params: The sampling parameters used to generate the outputs.
+        arrival_time: The arrival time of the request.
+    """
 
     def __init__(
         self,
@@ -167,46 +227,88 @@ class SequenceGroup:
         arrival_time: float,
     ) -> None:
         self.request_id = request_id
-        self.seqs = seqs
+        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
 
+    def get_max_num_running_seqs(self) -> int:
+        """The maximum number of sequences running in parallel in the remaining
+        lifetime of the request."""
+        if self.sampling_params.use_beam_search:
+            # For beam search, maximally there will always be `best_of` beam
+            # candidates running in the future.
+            return self.sampling_params.best_of
+        else:
+            if self.sampling_params.best_of > self.num_seqs():
+                # At prompt stage, the sequence group is not yet filled up
+                # and only have one sequence running. However, in the
+                # generation stage, we will have `best_of` sequences running.
+                return self.sampling_params.best_of
+            # At sampling stages, return the number of actual sequences
+            # running.
+            return self.num_seqs(status=SequenceStatus.RUNNING)
+
     def get_seqs(
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
         if status is None:
-            return self.seqs
+            return list(self.seqs_dict.values())
         else:
-            return [seq for seq in self.seqs if seq.status == status]
+            return [
+                seq for seq in self.seqs_dict.values() if seq.status == status
+            ]
+
+    def get_finished_seqs(self) -> List[Sequence]:
+        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
 
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         return len(self.get_seqs(status))
 
     def find(self, seq_id: int) -> Sequence:
-        for seq in self.seqs:
-            if seq.seq_id == seq_id:
-                return seq
-        raise ValueError(f'Sequence {seq_id} not found.')
+        if seq_id not in self.seqs_dict:
+            raise ValueError(f"Sequence {seq_id} not found.")
+        return self.seqs_dict[seq_id]
+
+    def add(self, seq: Sequence) -> None:
+        if seq.seq_id in self.seqs_dict:
+            raise ValueError(f"Sequence {seq.seq_id} already exists.")
+        self.seqs_dict[seq.seq_id] = seq
+
+    def remove(self, seq_id: int) -> None:
+        if seq_id not in self.seqs_dict:
+            raise ValueError(f"Sequence {seq_id} not found.")
+        del self.seqs_dict[seq_id]
 
     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.get_seqs())
 
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs)})")
+                f"num_seqs={len(self.seqs_dict)})")
 class SequenceGroupMetadata:
+    """Metadata for a sequence group. Used to create `InputMetadata`.
+
+    Args:
+        request_id: The ID of the request.
+        is_prompt: Whether the request is at prompt stage.
+        seq_data: The sequence data. (Seq id -> sequence data)
+        sampling_params: The sampling parameters used to generate the outputs.
+        block_tables: The block tables. (Seq id -> list of physical block
+            numbers)
+    """
 
     def __init__(
         self,
         request_id: str,
         is_prompt: bool,
-        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
+        seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
+        block_tables: Dict[int, List[int]],
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
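The per-argument inline comments moved into the docstring; both maps are still keyed by sequence ID. A hedged construction sketch — the SequenceData and SamplingParams constructors used here, and the block numbers, are assumptions for illustration only:

from vllm import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata

metadata = SequenceGroupMetadata(
    request_id="req-0",
    is_prompt=True,                           # prefill (prompt) stage
    seq_data={0: SequenceData([15496, 11])},  # seq id -> SequenceData (assumed ctor)
    sampling_params=SamplingParams(),         # assumed default-constructible
    block_tables={0: [3, 7]},                 # seq id -> physical block numbers (made up)
)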
@@ -216,29 +318,39 @@ class SequenceGroupMetadata:
 class SequenceOutputs:
+    """The model output associated with a sequence.
+
+    Args:
+        parent_seq_id: The ID of the parent sequence (for forking in beam
+            search).
+        output_token: The output token ID.
+        logprobs: The logprobs of the output token.
+            (Token id -> logP(x_i+1 | x_0, ..., x_i))
+    """
 
     def __init__(
         self,
-        seq_id: int,
         parent_seq_id: int,
         output_token: int,
-        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
+        logprobs: Dict[int, float],
     ) -> None:
-        self.seq_id = seq_id
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
-        return (f'SequenceOutputs(seq_id={self.seq_id}, '
-                f'parent_seq_id={self.parent_seq_id}, '
-                f'output_token={self.output_token}), '
-                f'logprobs={self.logprobs}')
+        return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
+                f"output_token={self.output_token}), "
+                f"logprobs={self.logprobs}")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
-            return NotImplemented
+            return NotImplementedError()
         return (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token
                 and self.logprobs == other.logprobs)
+
+
+# For each sequence group, we generate a list of SequenceOutputs object,
+# each of which contains one possible candidate for the next token.
+SamplerOutput = List[List[SequenceOutputs]]
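Concretely, a SamplerOutput holds one inner list per scheduled sequence group, and each inner list holds the SequenceOutputs candidates for that group's next token (several per group under beam search). A hedged illustration with made-up token IDs and logprobs:

# First group: two beam candidates forked from parent sequence 0.
# Second group: a single candidate from parent sequence 3.
sampler_output: SamplerOutput = [
    [
        SequenceOutputs(parent_seq_id=0, output_token=314,
                        logprobs={314: -0.2, 921: -1.9}),
        SequenceOutputs(parent_seq_id=0, output_token=921,
                        logprobs={314: -0.2, 921: -1.9}),
    ],
    [
        SequenceOutputs(parent_seq_id=3, output_token=50256,
                        logprobs={50256: -0.05}),
    ],
]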

Some files were not shown because too many files have changed in this diff