Compare commits

..

172 Commits

Author SHA1 Message Date
Simon Mo
4db5176d97 bump version to v0.5.4 (#7139)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.10, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.11, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.12, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.8, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.9, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.10, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.11, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.12, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.8, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.9, 2.4.0) (push) Has been cancelled
2024-08-05 14:39:48 -07:00
Tyler Michael Smith
4cf1dc39be [Bugfix][CI/Build] Fix CUTLASS FetchContent (#7171) 2024-08-05 14:22:57 -07:00
Tyler Michael Smith
6e4852ce28 [CI/Build] Suppress divide-by-zero and missing return statement warnings (#7001) 2024-08-05 16:00:01 -04:00
Tyler Michael Smith
8571ac4672 [Kernel] Update CUTLASS to 3.5.1 (#7085) 2024-08-05 15:13:43 -04:00
Rui Qiao
997cf78308 [Misc] Fix typo in GroupCoordinator.recv() (#7167)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
2024-08-05 11:10:16 -07:00
Aditya Paliwal
57f560aa23 [BugFix] Use args.trust_remote_code (#7121) 2024-08-05 09:26:14 -07:00
Nick Hill
003f8ee128 [BugFix] Use IP4 localhost form for zmq bind (#7163) 2024-08-05 08:41:03 -07:00
Bongwon Jang
e9630458c7 [SpecDecode] Support FlashInfer in DraftModelRunner (#6926) 2024-08-05 08:05:05 -07:00
Cade Daniel
82a1b1a82b [Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963) 2024-08-05 08:46:44 +00:00
Jungho Christopher Cho
c0d8f1636c [Model] SiglipVisionModel ported from transformers (#6942)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-08-05 06:22:12 +00:00
Cyrus Leung
cc08fc7225 [Frontend] Reapply "Factor out code for running uvicorn" (#7095) 2024-08-04 20:40:51 -07:00
Alphi
7b86e7c9cd [Model] Add multi-image support for minicpmv (#7122)
Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2024-08-05 09:23:17 +08:00
Jee Jee Li
f80ab3521c Clean up remaining Punica C information (#7027) 2024-08-04 15:37:08 -07:00
youkaichao
16a1cc9bb2 [misc][distributed] improve libcudart.so finding (#7127) 2024-08-04 11:31:51 -07:00
Thomas Parnell
b1c9aa3daa [Bugfix] [SpecDecode] Default speculative_draft_tensor_parallel_size to 1 when using MLPSpeculator (#7105)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
2024-08-04 07:13:18 -07:00
Jee Jee Li
179a6a36f2 [Model]Refactor MiniCPMV (#7020)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2024-08-04 08:12:41 +00:00
youkaichao
83c644fe7e [core][misc] simply output processing with shortcut code path (#7117) 2024-08-04 00:22:19 -07:00
youkaichao
9fadc7b7a0 [misc] add zmq in collect env (#7119) 2024-08-03 22:03:46 -07:00
Yihuan Bu
654bc5ca49 Support for guided decoding for offline LLM (#6878)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2024-08-04 03:12:09 +00:00
Jeff Fialho
825b044863 [Frontend] Warn if user max_model_len is greater than derived max_model_len (#7080)
Signed-off-by: Jefferson Fialho <jfialho@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
2024-08-03 16:01:38 -07:00
youkaichao
44dcb52e39 [ci][test] finalize fork_new_process_for_each_test (#7114) 2024-08-03 10:44:53 -07:00
Kuntai Du
67d745cc68 [CI] Temporarily turn off H100 performance benchmark (#7104) 2024-08-02 23:52:44 -07:00
Jee Jee Li
99d7cabd7b [LoRA] ReplicatedLinear support LoRA (#7081) 2024-08-02 22:40:19 -07:00
Zach Zheng
fb2c1c86c1 [Bugfix] Fix block table for seqs that have prefix cache hits (#7018) 2024-08-02 22:38:15 -07:00
Isotr0py
0c25435daa [Model] Refactor and decouple weight loading logic for InternVL2 model (#7067) 2024-08-02 22:36:14 -07:00
youkaichao
a0d164567c [ci][distributed] disable ray dag tests (#7099) 2024-08-02 22:32:04 -07:00
youkaichao
04e5583425 [ci][distributed] merge distributed test commands (#7097)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2024-08-02 21:33:53 -07:00
Cyrus Leung
8c025fa703 [Frontend] Factor out chat message parsing (#7055) 2024-08-02 21:31:27 -07:00
youkaichao
69ea15e5cc [ci][distributed] shorten wait time if server hangs (#7098) 2024-08-02 21:05:16 -07:00
Robert Shaw
ed812a73fa [ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-08-02 18:27:28 -07:00
youkaichao
708989341e [misc] add a flag to enable compile (#7092) 2024-08-02 16:18:45 -07:00
Rui Qiao
22e718ff1a [Misc] Revive to use loopback address for driver IP (#7091)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
2024-08-02 15:50:00 -07:00
Rui Qiao
05308891e2 [Core] Pipeline parallel with Ray ADAG (#6837)
Support pipeline-parallelism with Ray accelerated DAG.

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
2024-08-02 13:55:40 -07:00
Lucas Wilkinson
a8d604ca2a [Misc] Disambiguate quantized types via a new ScalarType (#6396) 2024-08-02 13:51:58 -07:00
Michael Goin
b482b9a5b1 [CI/Build] Add support for Python 3.12 (#7035) 2024-08-02 13:51:22 -07:00
youkaichao
806949514a [ci] set timeout for test_oot_registration.py (#7082) 2024-08-02 10:03:24 -07:00
Jie Fu (傅杰)
c16eaac500 [Hardware][Intel CPU] Update torch 2.4.0 for CPU backend (#6931) 2024-08-02 08:55:58 -07:00
Peng Guanwen
db35186391 [Core] Comment out unused code in sampler (#7023) 2024-08-02 00:58:26 -07:00
youkaichao
660dea1235 [cuda][misc] remove error_on_invalid_device_count_status (#7069) 2024-08-02 00:14:21 -07:00
Bongwon Jang
cf2a1a4d9d Fix tracing.py (#7065) 2024-08-01 23:28:00 -07:00
youkaichao
252357793d [ci][distributed] try to fix pp test (#7054) 2024-08-01 22:03:12 -07:00
Cyrus Leung
3bb4b1e4cd [mypy] Speed up mypy checking (#7056) 2024-08-01 19:49:43 -07:00
Lily Liu
954f7305a1 [Kernel] Fix input for flashinfer prefill wrapper. (#7008) 2024-08-01 18:44:16 -07:00
Woosuk Kwon
6ce01f3066 [Performance] Optimize get_seqs (#7051) 2024-08-01 18:29:52 -07:00
Tyler Michael Smith
6a11fdfbb8 [CI/Build][Bugfix] Fix CUTLASS header-only line (#7034) 2024-08-01 13:51:15 -07:00
Woosuk Kwon
805a8a75f2 [Misc] Support attention logits soft-capping with flash-attn (#7022) 2024-08-01 13:14:37 -07:00
omkar kakarparthi
562e580abc Update run-amd-test.sh (#7044) 2024-08-01 13:12:37 -07:00
Murali Andoorveedu
fc912e0886 [Models] Support Qwen model with PP (#6974)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
2024-08-01 12:40:43 -07:00
Michael Goin
f4fd390f5d [Bugfix] Lower gemma's unloaded_params exception to warning (#7002) 2024-08-01 12:01:07 -07:00
Michael Goin
fb3db61688 [CI/Build] Remove sparseml requirement from testing (#7037) 2024-08-01 12:00:51 -07:00
Isotr0py
2dd34371a6 [Bugfix] Fix RMSNorm forward in InternViT attention qk_layernorm (#6992) 2024-08-01 12:00:28 -07:00
Sage Moore
7e0861bd0b [CI/Build] Update PyTorch to 2.4.0 (#6951)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-08-01 11:11:24 -07:00
Alexei-V-Ivanov-AMD
a72a424b3e [Build/CI] Fixing Docker Hub quota issue. (#7043) 2024-08-01 11:07:37 -07:00
youkaichao
c8a7e93273 [core][scheduler] simplify and improve scheduler (#6867) 2024-07-31 23:51:09 -07:00
zifeitong
3c10591ef2 [Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954) 2024-07-31 21:13:34 -07:00
Aurick Qiao
0437492ea9 PP comm optimization: replace send with partial send + allgather (#6695)
Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
2024-07-31 20:15:42 -07:00
Travis Johnson
630dd9e0ae [Bugfix][Model] Skip loading lm_head weights if using tie_word_embeddings (#6758)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
2024-07-31 19:49:11 -07:00
Woosuk Kwon
23993a7997 [Bugfix][TPU] Do not use torch.Generator for TPUs (#6981) 2024-07-31 18:50:28 -07:00
xuyi
1d2e7fb73f [Model] Pipeline parallel support for Qwen2 (#6924) 2024-07-31 18:49:51 -07:00
Jee Jee Li
7ecee34321 [Kernel][RFC] Refactor the punica kernel based on Triton (#5036) 2024-07-31 17:12:24 -07:00
Simon Mo
7eb0cb4a14 Revert "[Frontend] Factor out code for running uvicorn" (#7012)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-07-31 16:34:26 -07:00
Michael Goin
a0dce9383a [Misc] Add compressed-tensors to optimized quant list (#7006) 2024-07-31 14:40:44 -07:00
Varun Sundar Rabindranath
35e9c12bfa [Kernel] Tuned int8 Cutlass Kernels for SM75 (T4) (#6996)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2024-07-31 14:40:32 -07:00
Varun Sundar Rabindranath
93548eb37e [Kernel] Enable FP8 Cutlass for Ada Lovelace (#6950)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2024-07-31 14:40:22 -07:00
Michael Goin
460c1884e3 [Bugfix] Support cpu offloading with fp8 quantization (#6960) 2024-07-31 12:47:46 -07:00
Cody Yu
bd70013407 [MISC] Introduce pipeline parallelism partition strategies (#6920)
Co-authored-by: youkaichao <youkaichao@126.com>
2024-07-31 12:02:17 -07:00
Avshalom Manevich
2ee8d3ba55 [Model] use FusedMoE layer in Jamba (#6935) 2024-07-31 12:00:24 -07:00
Cyrus Leung
daed30c4a9 [Bugfix] Fix feature size calculation for LLaVA-NeXT (#6982) 2024-07-31 23:46:17 +08:00
Alphi
2f4e108f75 [Bugfix] Clean up MiniCPM-V (#6939)
Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2024-07-31 14:39:19 +00:00
HandH1998
6512937de1 Support W4A8 quantization for vllm (#5218) 2024-07-31 07:55:21 -06:00
Fei
c0644cf9ce [Bugfix] fix logit processor excceed vocab size issue (#6927) 2024-07-31 16:16:01 +08:00
Woosuk Kwon
533d1932d2 [Bugfix][TPU] Set readonly=True for non-root devices (#6980) 2024-07-31 00:19:28 -07:00
Cyrus Leung
9f0e69b653 [CI/Build] Fix mypy errors (#6968) 2024-07-30 19:49:48 -07:00
Cyrus Leung
f230cc2ca6 [Bugfix] Fix broadcasting logic for multi_modal_kwargs (#6836) 2024-07-31 10:38:45 +08:00
Cyrus Leung
da1f7cc12a [mypy] Enable following imports for some directories (#6681) 2024-07-31 10:38:03 +08:00
Cade Daniel
c32ab8be1a [Speculative decoding] Add serving benchmark for llama3 70b + speculative decoding (#6964) 2024-07-31 00:53:21 +00:00
Cade Daniel
fb4f530bf5 [CI] [nightly benchmark] Do not re-download sharegpt dataset if exists (#6706) 2024-07-30 16:28:49 -07:00
Cade Daniel
79319cedfa [Nightly benchmarking suite] Remove pkill python from run benchmark suite (#6965) 2024-07-30 16:28:05 -07:00
Simon Mo
40c27a7cbb [Build] Temporarily Disable Kernels and LoRA tests (#6961) 2024-07-30 14:59:48 -07:00
youkaichao
6ca8031e71 [core][misc] improve free_finished_seq_groups (#6865)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-07-30 14:32:12 -07:00
Tyler Michael Smith
d7a299edaa [Kernel] Remove scaled_fp8_quant kernel padding footgun (#6842) 2024-07-30 16:37:01 -04:00
Sanger Steel
052b6f8ca4 [Bugfix] Fix tensorizer memory profiling bug during testing (#6881) 2024-07-30 11:48:50 -07:00
Ilya Lavrenov
5895b24677 [OpenVINO] Updated OpenVINO requirements and build docs (#6948) 2024-07-30 11:33:01 -07:00
Tyler Michael Smith
cbbc904470 [Kernel] Squash a few more warnings (#6914) 2024-07-30 13:50:42 -04:00
Nick Hill
5cf9254a9c [BugFix] Fix use of per-request seed with pipeline parallel (#6698) 2024-07-30 10:40:08 -07:00
fzyzcjy
f058403683 [Doc] Super tiny fix doc typo (#6949) 2024-07-30 09:14:03 -07:00
Roger Wang
c66c7f86ac [Bugfix] Fix PaliGemma MMP (#6930) 2024-07-30 02:20:57 -07:00
Woosuk Kwon
6e063ea35b [TPU] Fix greedy decoding (#6933) 2024-07-30 02:06:29 -07:00
Varun Sundar Rabindranath
af647fb8b3 [Kernel] Tuned int8 kernels for Ada Lovelace (#6848)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2024-07-29 20:24:58 -06:00
Tyler Michael Smith
61a97c32f6 [Kernel] Fix marlin divide-by-zero warnings (#6904) 2024-07-30 01:26:07 +00:00
Kevin H. Luu
4fbf4aa128 [ci] GHA workflow to remove ready label upon "/notready" comment (#6921)
Signed-off-by: kevin <kevin@anyscale.com>
2024-07-29 17:03:45 -07:00
Tyler Michael Smith
aae6d36f7e [Kernel] Remove unused variables in awq/gemm_kernels.cu (#6908) 2024-07-29 18:01:17 -06:00
Nick Hill
9f69d8245a [Frontend] New allowed_token_ids decoding request parameter (#6753) 2024-07-29 23:37:27 +00:00
Thomas Parnell
9a7e2d0534 [Bugfix] Allow vllm to still work if triton is not installed. (#6786)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
2024-07-29 14:51:27 -07:00
Earthwalker
7f8d612d24 [TPU] Support tensor parallelism in async llm engine (#6891) 2024-07-29 12:42:21 -07:00
Tyler Michael Smith
60d1c6e584 [Kernel] Fix deprecation function warnings squeezellm quant_cuda_kernel (#6901) 2024-07-29 09:59:02 -07:00
Peng Guanwen
db9e5708a9 [Core] Reduce unnecessary compute when logprobs=None (#6532) 2024-07-29 16:47:31 +00:00
Varun Sundar Rabindranath
766435e660 [Kernel] Tuned FP8 Kernels for Ada Lovelace (#6677)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2024-07-29 09:42:35 -06:00
Isotr0py
7cbd9ec7a9 [Model] Initialize support for InternVL2 series models (#6514)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-07-29 10:16:30 +00:00
Elsa Granger
3eeb148f46 [Misc] Pass cutlass_fp8_supported correctly in fbgemm_fp8 (#6871) 2024-07-28 11:13:49 -04:00
Michael Goin
b1366a9534 Add Nemotron to PP_SUPPORTED_MODELS (#6863) 2024-07-27 15:05:17 -07:00
Alexander Matveev
75acdaa4b6 [Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795) 2024-07-27 17:52:33 -04:00
Woosuk Kwon
fad5576c58 [TPU] Reduce compilation time & Upgrade PyTorch XLA version (#6856) 2024-07-27 10:28:33 -07:00
Chenggang Wu
f954d0715c [Docs] Add RunLLM chat widget (#6857) 2024-07-27 09:24:46 -07:00
Cyrus Leung
1ad86acf17 [Model] Initial support for BLIP-2 (#5920)
Co-authored-by: ywang96 <ywang@roblox.com>
2024-07-27 11:53:07 +00:00
Roger Wang
ecb33a28cb [CI/Build][Doc] Update CI and Doc for VLM example changes (#6860) 2024-07-27 09:54:14 +00:00
Wang Ran (汪然)
a57d75821c [bugfix] make args.stream work (#6831) 2024-07-27 09:07:02 +00:00
Roger Wang
925de97e05 [Bugfix] Fix VLM example typo (#6859) 2024-07-27 14:24:08 +08:00
Roger Wang
aa46953a20 [Misc][VLM][Doc] Consolidate offline examples for vision language models (#6858)
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2024-07-26 22:44:13 -07:00
Travis Johnson
593e79e733 [Bugfix] torch.set_num_threads() in multiproc_gpu_executor (#6802)
[Bugfix] Use torch.set_num_threads() to configure parallelism in multiproc_gpu_executor (#6802)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
2024-07-26 22:15:20 -07:00
Harry Mellor
c53041ae3b [Doc] Add missing mock import to docs conf.py (#6834) 2024-07-27 04:47:33 +00:00
Woosuk Kwon
52f07e3dec [Hardware][TPU] Implement tensor parallelism with Ray (#5871) 2024-07-26 20:54:27 -07:00
Joe
14dbd5a767 [Model] H2O Danube3-4b (#6451) 2024-07-26 20:47:50 -07:00
tomeras91
ed94e4f427 [Bugfix][Model] Jamba assertions and no chunked prefill by default for Jamba (#6784) 2024-07-26 20:45:31 -07:00
omrishiv
3c3012398e [Doc] add VLLM_TARGET_DEVICE=neuron to documentation for neuron (#6844)
Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com>
2024-07-26 20:20:16 -07:00
Woosuk Kwon
ced36cd89b [ROCm] Upgrade PyTorch nightly version (#6845) 2024-07-26 20:16:13 -07:00
Sanger Steel
969d032265 [Bugfix]: Fix Tensorizer test failures (#6835) 2024-07-26 20:02:25 -07:00
Lucas Wilkinson
55712941e5 [Bug Fix] Illegal memory access, FP8 Llama 3.1 405b (#6852) 2024-07-27 02:27:44 +00:00
Cyrus Leung
981b0d5673 [Frontend] Factor out code for running uvicorn (#6828) 2024-07-27 09:58:25 +08:00
Woosuk Kwon
d09b94ca58 [TPU] Support collective communications in XLA devices (#6813) 2024-07-27 01:45:57 +00:00
chenqianfzh
bb5494676f enforce eager mode with bnb quantization temporarily (#6846) 2024-07-27 01:32:20 +00:00
Gurpreet Singh Dhami
b5f49ee55b Update README.md (#6847) 2024-07-27 00:26:45 +00:00
Zhanghao Wu
150a1ffbfd [Doc] Update SkyPilot doc for wrong indents and instructions for update service (#4283) 2024-07-26 14:39:10 -07:00
Michael Goin
281977bd6e [Doc] Add Nemotron to supported model docs (#6843) 2024-07-26 17:32:44 -04:00
Li, Jiang
3bbb4936dc [Hardware] [Intel] Enable Multiprocessing and tensor parallel in CPU backend and update documentation (#6125) 2024-07-26 13:50:10 -07:00
Woosuk Kwon
aa4867791e [Misc][TPU] Support TPU in initialize_ray_cluster (#6812) 2024-07-26 19:39:49 +00:00
Woosuk Kwon
71734f1bf2 [Build/CI][ROCm] Minor simplification to Dockerfile.rocm (#6811) 2024-07-26 12:28:32 -07:00
Tyler Michael Smith
50704f52c4 [Bugfix][Kernel] Promote another index to int64_t (#6838) 2024-07-26 18:41:04 +00:00
Michael Goin
07278c37dd [Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) (#6611) 2024-07-26 14:33:42 -04:00
youkaichao
85ad7e2d01 [doc][debugging] add known issues for hangs (#6816) 2024-07-25 21:48:05 -07:00
Peng Guanwen
89a84b0bb7 [Core] Use array to speedup padding (#6779) 2024-07-25 21:31:31 -07:00
Anthony Platanios
084a01fd35 [Bugfix] [Easy] Fixed a bug in the multiprocessing GPU executor. (#6770) 2024-07-25 21:25:35 -07:00
QQSong
062a1d0fab Fix ReplicatedLinear weight loading (#6793) 2024-07-25 19:24:58 -07:00
Kevin H. Luu
2eb9f4ff26 [ci] Mark tensorizer as soft fail and separate from grouped test (#6810)
[ci] Mark tensorizer test as soft fail and separate it from grouped test in fast check (#6810)
Signed-off-by: kevin <kevin@anyscale.com>
2024-07-25 18:08:33 -07:00
youkaichao
443c7cf4cf [ci][distributed] fix flaky tests (#6806) 2024-07-25 17:44:09 -07:00
SangBin Cho
1adddb14bf [Core] Fix ray forward_dag error mssg (#6792) 2024-07-25 16:53:25 -07:00
Woosuk Kwon
b7215de2c5 [Docs] Publish 5th meetup slides (#6799) 2024-07-25 16:47:55 -07:00
youkaichao
f3ff63c3f4 [doc][distributed] improve multinode serving doc (#6804) 2024-07-25 15:38:32 -07:00
Lucas Wilkinson
cd7edc4e87 [Bugfix] Fix empty (nullptr) channelwise scales when loading wNa16 using compressed tensors (#6798) 2024-07-25 15:05:09 -07:00
Kuntai Du
6a1e25b151 [Doc] Add documentations for nightly benchmarks (#6412) 2024-07-25 11:57:16 -07:00
Tyler Michael Smith
95db75de64 [Bugfix] Add synchronize to prevent possible data race (#6788)
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2024-07-25 10:40:01 -07:00
Michael Goin
65b1f121c8 [Bugfix] Fix kv_cache_dtype=fp8 without scales for FP8 checkpoints (#6761) 2024-07-25 09:46:15 -07:00
Robert Shaw
889da130e7 [ Misc ] fp8-marlin channelwise via compressed-tensors (#6524)
Co-authored-by: mgoin <michael@neuralmagic.com>
2024-07-25 09:46:04 -07:00
Alphi
b75e314fff [Bugfix] Add image placeholder for OpenAI Compatible Server of MiniCPM-V (#6787)
Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2024-07-25 09:42:49 -07:00
Chang Su
316a41ac1d [Bugfix] Fix encoding_format in examples/openai_embedding_client.py (#6755) 2024-07-24 22:48:07 -07:00
Alexander Matveev
0310029a2f [Bugfix] Fix awq_marlin and gptq_marlin flags (#6745) 2024-07-24 22:34:11 -07:00
Cody Yu
309aaef825 [Bugfix] Fix decode tokens w. CUDA graph (#6757) 2024-07-24 22:33:56 -07:00
Alphi
9e169a4c61 [Model] Adding support for MiniCPM-V (#4087) 2024-07-24 20:59:30 -07:00
Evan Z. Liu
5689e256ba [Frontend] Represent tokens with identifiable strings (#6626) 2024-07-25 09:51:00 +08:00
youkaichao
740374d456 [core][distributed] fix zmq hang (#6759) 2024-07-24 17:37:12 -07:00
Hongxia Yang
d88c458f44 [Doc][AMD][ROCm]Added tips to refer to mi300x tuning guide for mi300x users (#6754) 2024-07-24 14:32:57 -07:00
Michael Goin
421e218b37 [Bugfix] Bump transformers to 4.43.2 (#6752) 2024-07-24 13:22:16 -07:00
Antoni Baum
5448f67635 [Core] Tweaks to model runner/input builder developer APIs (#6712) 2024-07-24 12:17:12 -07:00
Antoni Baum
0e63494cf3 Add fp8 support to reshape_and_cache_flash (#6667) 2024-07-24 18:36:52 +00:00
Daniele
ee812580f7 [Frontend] split run_server into build_server and run_server (#6740) 2024-07-24 10:36:04 -07:00
Allen.Dou
40468b13fa [Bugfix] Miscalculated latency lead to time_to_first_token_seconds inaccurate. (#6686) 2024-07-24 08:58:42 -07:00
Nick Hill
2cf0df3381 [Bugfix] Fix speculative decode seeded test (#6743) 2024-07-24 08:58:31 -07:00
LF Marques
545146349c Adding f-string to validation error which is missing (#6748) 2024-07-24 08:55:53 -07:00
liuyhwangyh
f4f8a9d892 [Bugfix]fix modelscope compatible issue (#6730) 2024-07-24 05:04:46 -07:00
Alexei-V-Ivanov-AMD
b570811706 [Build/CI] Update run-amd-test.sh. Enable Docker Hub login. (#6711) 2024-07-24 05:01:14 -07:00
Woosuk Kwon
ccc4a73257 [Docs][ROCm] Detailed instructions to build from source (#6680) 2024-07-24 01:07:23 -07:00
Roger Wang
0a740a11ba [Bugfix] Fix token padding for chameleon (#6724) 2024-07-24 01:05:09 -07:00
Nick Hill
c882a7f5b3 [SpecDecoding] Update MLPSpeculator CI tests to use smaller model (#6714) 2024-07-24 07:34:22 +00:00
William Lin
5e8ca973eb [Bugfix] fix flashinfer cudagraph capture for PP (#6708) 2024-07-24 01:49:44 +00:00
dongmao zhang
87525fab92 [bitsandbytes]: support read bnb pre-quantized model (#5753)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-07-23 23:45:09 +00:00
Thomas Parnell
2f808e69ab [Bugfix] StatLoggers: cache spec decode metrics when they get collected. (#6645)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
2024-07-23 23:05:05 +00:00
Michael Goin
01c16ede6b [CI] Add smoke test for non-uniform AutoFP8 quantization (#6702) 2024-07-23 22:45:12 +00:00
youkaichao
72fc704803 [build] relax wheel size limit (#6704) 2024-07-23 14:03:49 -07:00
Roger Wang
1bedf210e3 Bump transformers version for Llama 3.1 hotfix and patch Chameleon (#6690) 2024-07-23 13:47:48 -07:00
Travis Johnson
507ef787d8 [Model] Pipeline Parallel Support for DeepSeek v2 (#6519)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
2024-07-23 12:22:09 -07:00
Yehoshua Cohen
58f53034ad [Frontend] Add Usage data in each chunk for chat_serving. #6540 (#6652) 2024-07-23 11:41:55 -07:00
Michael Goin
0eb0757bef [Misc] Add ignored layers for fp8 quantization (#6657) 2024-07-23 14:04:04 -04:00
372 changed files with 21976 additions and 8624 deletions

View File

@@ -1,7 +1,7 @@
import os import os
import zipfile import zipfile
MAX_SIZE_MB = 200 MAX_SIZE_MB = 250
def print_top_10_largest_files(zip_file): def print_top_10_largest_files(zip_file):

View File

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.409
- name: "exact_match,flexible-extract"
value: 0.406
limit: 1000
num_fewshot: 5

View File

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
model_name: "nvidia/Minitron-4B-Base"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.252
- name: "exact_match,flexible-extract"
value: 0.252
limit: 1000
num_fewshot: 5

View File

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.578
- name: "exact_match,flexible-extract"
value: 0.585
limit: 1000
num_fewshot: 5

View File

@@ -4,4 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml

View File

@@ -3,30 +3,51 @@
## Introduction ## Introduction
This directory contains the performance benchmarking CI for vllm. This directory contains two sets of benchmark for vllm.
The goal is to help developers know the impact of their PRs on the performance of vllm. - Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance
- Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm.
This benchmark will be *triggered* upon:
- A PR being merged into vllm.
- Every commit for those PRs with `perf-benchmarks` label.
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for more GPUs is comming later), with different models. See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
## Performance benchmark quick overview
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.
**Benchmarking Duration**: about 1hr. **Benchmarking Duration**: about 1hr.
**For benchmarking developers**: please try your best to constraint the duration of benchmarking to less than 1.5 hr so that it won't take forever to run. **For benchmarking developers**: please try your best to constraint the duration of benchmarking to about 1 hr so that it won't take forever to run.
## Configuring the workload ## Nightly benchmark quick overview
The benchmarking workload contains three parts: **Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B.
- Latency tests in `latency-tests.json`.
- Throughput tests in `throughput-tests.json`.
- Serving tests in `serving-tests.json`.
See [descriptions.md](tests/descriptions.md) for detailed descriptions. **Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy.
### Latency test **Benchmarking Duration**: about 3.5hrs.
## Trigger the benchmark
Performance benchmark will be triggered when:
- A PR being merged into vllm.
- Every commit for those PRs with `perf-benchmarks` label.
Nightly benchmark will be triggered when:
- Every commit for those PRs with `nightly-benchmarks` label.
## Performance benchmark details
See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
#### Latency test
Here is an example of one test inside `latency-tests.json`: Here is an example of one test inside `latency-tests.json`:
@@ -54,12 +75,12 @@ Note that the performance numbers are highly sensitive to the value of the param
WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file. WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file.
### Throughput test #### Throughput test
The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`. The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`.
The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot. The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot.
### Serving test #### Serving test
We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example: We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example:
``` ```
@@ -96,9 +117,36 @@ The number of this test is less stable compared to the delay and latency benchma
WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`.
## Visualizing the results #### Visualizing the results
The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results.
You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. You can find the result presented as a table inside the `buildkite/performance-benchmark` job page.
If you do not see the table, please wait till the benchmark finish running. If you do not see the table, please wait till the benchmark finish running.
The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file.
The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking.
## Nightly test details
See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines.
#### Workflow
- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines.
- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container.
- The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark.
- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.
#### Nightly tests
In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is highly similar to performance benchmark.
#### Docker containers
The docker containers for benchmarking are specified in `nightly-pipeline.yaml`.
WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`.
WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git).

View File

@@ -42,20 +42,20 @@ steps:
- name: devshm - name: devshm
emptyDir: emptyDir:
medium: Memory medium: Memory
- label: "H100" # - label: "H100"
agents: # agents:
queue: H100 # queue: H100
plugins: # plugins:
- docker#v5.11.0: # - docker#v5.11.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
command: # command:
- bash # - bash
- .buildkite/nightly-benchmarks/run-benchmarks-suite.sh # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
mount-buildkite-agent: true # mount-buildkite-agent: true
propagate-environment: true # propagate-environment: true
ipc: host # ipc: host
gpus: all # gpus: all
environment: # environment:
- VLLM_USAGE_SOURCE # - VLLM_USAGE_SOURCE
- HF_TOKEN # - HF_TOKEN

View File

@@ -34,6 +34,15 @@ check_hf_token() {
fi fi
} }
ensure_sharegpt_downloaded() {
local FILE=ShareGPT_V3_unfiltered_cleaned_split.json
if [ ! -f "$FILE" ]; then
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE
else
echo "$FILE already exists."
fi
}
json2args() { json2args() {
# transforms the JSON string to command line args, and '_' is replaced to '-' # transforms the JSON string to command line args, and '_' is replaced to '-'
# example: # example:
@@ -73,11 +82,6 @@ kill_gpu_processes() {
echo "All GPU processes have been killed." echo "All GPU processes have been killed."
fi fi
# Sometimes kill with pid doesn't work properly, we can also kill all process running python or python3
# since we are in container anyway
pkill -9 -f python
pkill -9 -f python3
# waiting for GPU processes to be fully killed # waiting for GPU processes to be fully killed
# loop while nvidia-smi returns any processes # loop while nvidia-smi returns any processes
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
@@ -355,7 +359,7 @@ main() {
# prepare for benchmarking # prepare for benchmarking
cd benchmarks || exit 1 cd benchmarks || exit 1
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ensure_sharegpt_downloaded
declare -g RESULTS_FOLDER=results/ declare -g RESULTS_FOLDER=results/
mkdir -p $RESULTS_FOLDER mkdir -p $RESULTS_FOLDER
QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/

View File

@@ -55,5 +55,26 @@
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"num_prompts": 200 "num_prompts": 200
} }
},
{
"test_name": "serving_llama70B_tp4_sharegpt_specdecode",
"qps_list": [2],
"server_parameters": {
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
"disable_log_requests": "",
"tensor_parallel_size": 4,
"swap_space": 16,
"speculative_model": "turboderp/Qwama-0.5B-Instruct",
"num_speculative_tokens": 4,
"speculative_draft_tensor_parallel_size": 1,
"use_v2_block_manager": ""
},
"client_parameters": {
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"num_prompts": 200
}
} }
] ]

View File

@@ -55,7 +55,7 @@ while true; do
done done
echo "--- Pulling container" echo "--- Pulling container"
image_name="rocmshared/vllm-ci:${BUILDKITE_COMMIT}" image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
docker pull ${image_name} docker pull ${image_name}

View File

@@ -3,26 +3,38 @@
set -ex set -ex
# Try building the docker image # Try building the docker image
docker build -t cpu-test -f Dockerfile.cpu . numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu .
docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
# Setup cleanup # Setup cleanup
remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; } remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; }
trap remove_docker_container EXIT trap remove_docker_container EXIT
remove_docker_container remove_docker_container
# Run the image # Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2 --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2
# offline inference # offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test # Run basic model test
docker exec cpu-test bash -c "cd tests; docker exec cpu-test bash -c "
pip install pytest Pillow protobuf pip install pytest Pillow protobuf
cd ../ pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
# online inference
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"

View File

@@ -17,11 +17,10 @@ steps:
- pytest -v -s test_utils.py # Utils - pytest -v -s test_utils.py # Utils
- pytest -v -s worker # Worker - pytest -v -s worker # Worker
- label: Tensorizer, Metrics, Tracing Test - label: Metrics, Tracing Test
fast_check: true fast_check: true
fast_check_only: true fast_check_only: true
commands: commands:
- apt-get install -y curl libsodium23 && pytest -v -s tensorizer_loader # Tensorizer
- pytest -v -s metrics # Metrics - pytest -v -s metrics # Metrics
- "pip install \ - "pip install \
opentelemetry-sdk \ opentelemetry-sdk \
@@ -45,7 +44,7 @@ steps:
fast_check: true fast_check: true
commands: commands:
# This flashinfer installation will fail on AMD ROCm, so it is set as optional. # This flashinfer installation will fail on AMD ROCm, so it is set as optional.
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
- pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py - pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
@@ -57,7 +56,6 @@ steps:
fast_check: true fast_check: true
commands: commands:
- pytest -v -s core - pytest -v -s core
- pytest -v -s distributed/test_parallel_state.py
- label: Distributed Comm Ops Test - label: Distributed Comm Ops Test
#mirror_hardwares: [amd] #mirror_hardwares: [amd]
@@ -84,20 +82,9 @@ steps:
num_gpus: 2 num_gpus: 2
commands: commands:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
@@ -109,11 +96,6 @@ steps:
fast_check: true fast_check: true
commands: commands:
- pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_pynccl.py
# We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
- label: Pipeline Parallelism Test - label: Pipeline Parallelism Test
@@ -141,14 +123,13 @@ steps:
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
mirror_hardwares: [amd] mirror_hardwares: [amd]
commands: commands:
# install aws cli for llava_example.py
# install tensorizer for tensorize_vllm_model.py # install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer - pip install awscli tensorizer
- python3 offline_inference.py - python3 offline_inference.py
- python3 cpu_offload.py - python3 cpu_offload.py
- python3 offline_inference_with_prefix.py - python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py - python3 llm_engine_example.py
- python3 llava_example.py - python3 offline_inference_vision_language.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- label: Inputs Test - label: Inputs Test
@@ -157,17 +138,17 @@ steps:
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s multimodal - pytest -v -s multimodal
- label: Kernels Test %N # - label: Kernels Test %N
#mirror_hardwares: [amd] # #mirror_hardwares: [amd]
commands: # commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl # - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT # - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4 # parallelism: 4
- label: Models Test - label: Models Test
#mirror_hardwares: [amd] #mirror_hardwares: [amd]
commands: commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- pytest -v -s models -m \"not vlm\" - pytest -v -s models -m \"not vlm\"
- label: Vision Language Models Test - label: Vision Language Models Test
@@ -204,23 +185,24 @@ steps:
- export VLLM_ATTENTION_BACKEND=XFORMERS - export VLLM_ATTENTION_BACKEND=XFORMERS
- pytest -v -s spec_decode - pytest -v -s spec_decode
- label: LoRA Test %N # - label: LoRA Test %N
#mirror_hardwares: [amd] # #mirror_hardwares: [amd]
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py # command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4 # parallelism: 4
- label: LoRA Long Context (Distributed) # - label: LoRA Long Context (Distributed)
#mirror_hardwares: [amd] # #mirror_hardwares: [amd]
num_gpus: 4 # num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs. # # This test runs llama 13B, so it is required to run on 4 GPUs.
commands: # commands:
# FIXIT: find out which code initialize cuda before running the test # # FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it # # before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn # - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s -x lora/test_long_context.py # - pytest -v -s -x lora/test_long_context.py
- label: Tensorizer Test - label: Tensorizer Test
#mirror_hardwares: [amd] #mirror_hardwares: [amd]
fast_check: true
commands: commands:
- apt-get install -y curl libsodium23 - apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
@@ -281,9 +263,6 @@ steps:
# NOTE: don't test llama model here, it seems hf implementation is buggy # NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details # see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py - pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s -x lora/test_mixtral.py - pytest -v -s -x lora/test_mixtral.py

View File

@@ -30,12 +30,6 @@ jobs:
run: | run: |
EXCLUDES=( EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu' 'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
) )
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
@@ -32,22 +32,17 @@ jobs:
pip install types-setuptools pip install types-setuptools
- name: Mypy - name: Mypy
run: | run: |
mypy tests --config-file pyproject.toml mypy
mypy vllm/*.py --config-file pyproject.toml mypy tests --follow-imports skip
mypy vllm/attention --config-file pyproject.toml mypy vllm/attention --follow-imports skip
mypy vllm/core --config-file pyproject.toml mypy vllm/core --follow-imports skip
mypy vllm/distributed --config-file pyproject.toml mypy vllm/distributed --follow-imports skip
mypy vllm/engine --config-file pyproject.toml mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --config-file pyproject.toml mypy vllm/executor --follow-imports skip
mypy vllm/inputs --config-file pyproject.toml mypy vllm/lora --follow-imports skip
mypy vllm/logging --config-file pyproject.toml mypy vllm/model_executor --follow-imports skip
mypy vllm/lora --config-file pyproject.toml mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/spec_decode --follow-imports skip
mypy vllm/multimodal --config-file pyproject.toml mypy vllm/worker --follow-imports skip
mypy vllm/platforms --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml

View File

@@ -48,8 +48,8 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
os: ['ubuntu-20.04'] os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11'] python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
pytorch-version: ['2.3.1'] # Must be the most recent version that meets requirements-cuda.txt. pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1'] cuda-version: ['11.8', '12.1']
steps: steps:

View File

@@ -0,0 +1,23 @@
name: Remove ready Label on notready Comment
on:
issue_comment:
types: [created]
jobs:
add-ready-label:
runs-on: ubuntu-latest
if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready')
steps:
- name: Remove ready label
uses: actions/github-script@v5
with:
script: |
github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: 'ready'
})
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt
# Limit the number of parallel jobs to avoid OOM # Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1 export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1
# Make sure release wheels are built for the following architectures # Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
# Build # Build

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@@ -10,6 +10,7 @@ build:
sphinx: sphinx:
configuration: docs/source/conf.py configuration: docs/source/conf.py
fail_on_warning: true
# If using Sphinx, optionally build your docs in additional formats such as PDF # If using Sphinx, optionally build your docs in additional formats such as PDF
formats: formats:

View File

@@ -14,7 +14,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Supported python versions. These versions will be searched in order, the # Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py. # first match will be selected. These should be kept in sync with setup.py.
# #
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
# Supported NVIDIA architectures. # Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
@@ -32,7 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch # requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm # versions are derived from Dockerfile.rocm
# #
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1") set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
# #
@@ -66,6 +66,39 @@ endif()
# #
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
message(STATUS "Enabling core extension.")
# Define _core_C extension
# built for (almost) every target platform, (excludes TPU and Neuron)
set(VLLM_EXT_SRC
"csrc/core/torch_bindings.cpp")
define_gpu_extension_target(
_core_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI)
add_dependencies(default _core_C)
# #
# Forward the non-CUDA device extensions to external CMake scripts. # Forward the non-CUDA device extensions to external CMake scripts.
# #
@@ -74,7 +107,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
if (VLLM_TARGET_DEVICE STREQUAL "cpu") if (VLLM_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
else() else()
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") return()
endif() endif()
return() return()
endif() endif()
@@ -132,7 +165,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
# #
# Define extension targets # Define other extension targets
# #
# #
@@ -156,12 +189,13 @@ set(VLLM_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent) include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
FetchContent_Declare( FetchContent_Declare(
cutlass cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# CUTLASS 3.5.0 # CUTLASS 3.5.1
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
GIT_PROGRESS TRUE
) )
FetchContent_MakeAvailable(cutlass) FetchContent_MakeAvailable(cutlass)
@@ -170,6 +204,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
@@ -200,7 +235,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
@@ -222,76 +257,7 @@ define_gpu_extension_target(
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
#
# _punica_C extension
#
set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu"
"csrc/punica/torch_bindings.cpp")
#
# Copy GPU compilation flags+update for punica
#
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
"-D__CUDA_NO_HALF2_OPERATORS__")
#
# Filter out CUDA architectures < 8.0 for punica.
#
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
set(VLLM_PUNICA_GPU_ARCHES)
foreach(ARCH ${VLLM_GPU_ARCHES})
string_to_ver(CODE_VER ${ARCH})
if (CODE_VER GREATER_EQUAL 8.0)
list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
endif()
endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()
if (VLLM_PUNICA_GPU_ARCHES)
define_gpu_extension_target(
_punica_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
else()
message(WARNING "Unable to create _punica_C target because none of the "
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
endif()
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
@@ -300,12 +266,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling moe extension.") message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C) add_dependencies(default _moe_C)
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
# there are supported target arches.
if (VLLM_PUNICA_GPU_ARCHES AND
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
message(STATUS "Enabling punica extension.")
add_dependencies(default _punica_C)
endif()
endif() endif()

View File

@@ -42,6 +42,7 @@ WORKDIR /workspace
# install build and runtime dependencies # install build and runtime dependencies
COPY requirements-common.txt requirements-common.txt COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt python3 -m pip install -r requirements-cuda.txt
@@ -78,6 +79,7 @@ COPY setup.py setup.py
COPY cmake cmake COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt COPY CMakeLists.txt CMakeLists.txt
COPY requirements-common.txt requirements-common.txt COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml COPY pyproject.toml pyproject.toml
COPY vllm vllm COPY vllm vllm
@@ -88,8 +90,6 @@ ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc # number of threads used by nvcc
ARG nvcc_threads=8 ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
ARG buildkite_commit ARG buildkite_commit
ENV BUILDKITE_COMMIT=${buildkite_commit} ENV BUILDKITE_COMMIT=${buildkite_commit}
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################

View File

@@ -2,8 +2,8 @@
FROM ubuntu:22.04 AS cpu-test-1 FROM ubuntu:22.04 AS cpu-test-1
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
@@ -13,8 +13,9 @@ RUN pip install intel-openmp
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD" ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
RUN echo 'ulimit -c 0' >> ~/.bashrc
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
RUN pip install --upgrade pip \ RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy && pip install wheel packaging ninja "setuptools>=49.4.0" numpy

View File

@@ -1,7 +1,7 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used # The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server. # to run the OpenAI compatible server.
FROM ubuntu:20.04 AS dev FROM ubuntu:22.04 AS dev
RUN apt-get update -y && \ RUN apt-get update -y && \
apt-get install -y python3-pip git apt-get install -y python3-pip git
@@ -13,12 +13,15 @@ COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/ COPY requirements-openvino.txt /workspace/vllm/
COPY vllm/ /workspace/vllm/vllm COPY vllm/ /workspace/vllm/vllm
COPY csrc/core /workspace/vllm/csrc/core
COPY cmake/utils.cmake /workspace/vllm/cmake/
COPY CMakeLists.txt /workspace/vllm/
COPY setup.py /workspace/vllm/ COPY setup.py /workspace/vllm/
# install build requirements # install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with OpenVINO backend # build vLLM with OpenVINO backend
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
COPY examples/ /workspace/vllm/examples COPY examples/ /workspace/vllm/examples
COPY benchmarks/ /workspace/vllm/benchmarks COPY benchmarks/ /workspace/vllm/benchmarks

View File

@@ -53,10 +53,10 @@ RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(whic
# Install torch == 2.5.0 on ROCm # Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \ *"rocm-6.1"*) \
python3 -m pip uninstall -y torch torchaudio torchvision \ python3 -m pip uninstall -y torch torchvision \
&& python3 -m pip install --no-cache-dir --pre \ && python3 -m pip install --no-cache-dir --pre \
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \ torch==2.5.0.dev20240726 \
torchvision==0.20.0.dev20240710 \ torchvision==0.20.0.dev20240726 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
*) ;; esac *) ;; esac
@@ -127,19 +127,11 @@ FROM base AS final
# Import the vLLM development directory from the build context # Import the vLLM development directory from the build context
COPY . . COPY . .
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
RUN case "$(which python3)" in \
*"/opt/conda/envs/py_3.9"*) \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
*) ;; esac
# Package upgrades for useful functionality or to avoid dependency issues # Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0 # Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning # Silences the HF Tokenizers warning

View File

@@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20240713" ARG NIGHTLY_DATE="20240726"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE FROM $BASE_IMAGE
@@ -12,6 +12,9 @@ RUN pip install "numpy<2"
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Fix FastAPI dependence
RUN pip install "starlette<0.38.0"
# Build vLLM. # Build vLLM.
COPY . /workspace/vllm COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu" ENV VLLM_TARGET_DEVICE="tpu"

View File

@@ -1,4 +1,5 @@
include LICENSE include LICENSE
include requirements-adag.txt
include requirements-common.txt include requirements-common.txt
include requirements-cuda.txt include requirements-cuda.txt
include requirements-rocm.txt include requirements-rocm.txt

View File

@@ -16,16 +16,8 @@ Easy, fast, and cheap LLM serving for everyone
--- ---
**The Fifth vLLM Bay Area Meetup (July 24th 5pm-8pm PT)**
We are excited to announce our fifth vLLM Meetup!
Join us to hear the vLLM's recent updates and the upcoming roadmap.
Additionally, our collaborators from AWS will be presenting their insights and experiences in deploying vLLM.
Register now [here](https://lu.ma/lp0gyjqr) and be part of the event!
---
*Latest News* 🔥 *Latest News* 🔥
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
@@ -47,7 +39,7 @@ vLLM is fast with:
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache - Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
- Optimized CUDA kernels - Optimized CUDA kernels
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/3924) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)). **Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
vLLM is flexible and easy to use with: vLLM is flexible and easy to use with:

View File

@@ -13,7 +13,7 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1] DEFAULT_TP_SIZES = [1]
@@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
timers = [] timers = []
# pytorch impl # pytorch impl - bfloat16
timers.append( timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl, torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales")) "pytorch_bf16_bf16_bf16_matmul-no-scales"))
# pytorch impl - float16
timers.append(
bench_fn(a.to(dtype=torch.float16, device="cuda"),
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
torch.float16, label, sub_label, pytorch_mm_impl,
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
# cutlass impl # cutlass impl
timers.append( timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,

View File

@@ -7,16 +7,17 @@ from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, marlin_quantize) MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize) marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights) gptq_pack, gptq_quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
@@ -27,13 +28,14 @@ K_FULL_OPTS = [False, True]
def bench_run(results: List[benchmark.Measurement], model: str, def bench_run(results: List[benchmark.Measurement], model: str,
act_order: bool, is_k_full: bool, num_bits: int, group_size: int, act_order: bool, is_k_full: bool, quant_type: ScalarType,
size_m: int, size_k: int, size_n: int): group_size: int, size_m: int, size_k: int, size_n: int):
label = "Quant Matmul" label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, " sub_label = ("{}, act={} k_full={}, q={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, "MKN=({}x{}x{})".format(model, act_order, is_k_full,
group_size, size_m, size_k, size_n)) str(quant_type), group_size, size_m,
size_k, size_n))
print(f"Testing: {sub_label}") print(f"Testing: {sub_label}")
@@ -50,16 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str,
marlin_g_idx, marlin_g_idx,
marlin_sort_indices, marlin_sort_indices,
marlin_rand_perm, marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order) ) = marlin_quantize(b, quant_type, group_size, act_order)
# Marlin_24 quant # Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
# GPTQ quant # GPTQ quant
(w_ref, q_w, s, g_idx, (w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order) rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" # For act_order, sort the "weights" and "g_idx"
# so that group ids are increasing # so that group ids are increasing
@@ -73,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str,
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL) GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
globals = { globals = {
# Gen params # Gen params
"num_bits": num_bits, "quant_type": quant_type,
"group_size": group_size, "group_size": group_size,
"size_m": size_m, "size_m": size_m,
"size_n": size_n, "size_n": size_n,
@@ -87,6 +92,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
"marlin_w_ref": marlin_w_ref, "marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w, "marlin_q_w": marlin_q_w,
"marlin_s": marlin_s, "marlin_s": marlin_s,
"marlin_zp": marlin_zp,
"marlin_g_idx": marlin_g_idx, "marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices, "marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm, "marlin_rand_perm": marlin_rand_perm,
@@ -125,19 +131,29 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_gemm", description="gptq_marlin_gemm_fp16",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))
if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
@@ -147,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
@@ -183,12 +199,13 @@ def main(args):
) > 0 and is_k_full not in args.limit_k_full: ) > 0 and is_k_full not in args.limit_k_full:
continue continue
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: for quant_type in query_marlin_supported_quant_types(
if len(args.limit_num_bits False):
) > 0 and num_bits not in args.limit_num_bits: if len(args.limit_num_bits) > 0 and \
quant_type.size_bits not in args.limit_num_bits:
continue continue
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
if len( if len(
args.limit_group_size args.limit_group_size
) > 0 and group_size not in args.limit_group_size: ) > 0 and group_size not in args.limit_group_size:
@@ -202,8 +219,8 @@ def main(args):
for size_m in args.batch_sizes: for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full, bench_run(results, model, act_order, is_k_full,
num_bits, group_size, size_m, size_k, quant_type, group_size, size_m,
size_n) size_k, size_n)
compare = benchmark.Compare(results) compare = benchmark.Compare(results)
compare.print() compare.print()

View File

@@ -175,7 +175,7 @@ if __name__ == '__main__':
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 192, 256], choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128) default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")

View File

@@ -94,7 +94,7 @@ if __name__ == '__main__':
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 192, 256], choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128) default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument("--dtype",

View File

@@ -83,6 +83,8 @@ endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
list(APPEND LIBS "numa")
# #
# Define extension targets # Define extension targets
@@ -95,6 +97,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp" "csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp" "csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp" "csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp" "csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp" "csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp") "csrc/cpu/torch_bindings.cpp")
@@ -104,11 +107,11 @@ define_gpu_extension_target(
DESTINATION vllm DESTINATION vllm
LANGUAGE CXX LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS} COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3 USE_SABI 3
WITH_SOABI WITH_SOABI
) )
add_custom_target(default)
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
add_dependencies(default _C) add_dependencies(default _C)

View File

@@ -181,7 +181,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
# #
# The torch cmake setup hardcodes the detected architecture flags in # The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't modified on a per-target basis, e.g. for the `punica` extension. # can't modified on a per-target basis.
# So, all the `-gencode` flags need to be extracted and removed from # So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compiler_options` for adding target # Since it's not possible to use `target_compiler_options` for adding target

View File

@@ -65,6 +65,7 @@ DEFAULT_CONDA_PATTERNS = {
"optree", "optree",
"nccl", "nccl",
"transformers", "transformers",
"zmq",
} }
DEFAULT_PIP_PATTERNS = { DEFAULT_PIP_PATTERNS = {
@@ -77,6 +78,7 @@ DEFAULT_PIP_PATTERNS = {
"onnx", "onnx",
"nccl", "nccl",
"transformers", "transformers",
"zmq",
} }

View File

@@ -706,7 +706,7 @@ void paged_attention_v1_launcher(
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
@@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
case 112: case 112:
LAUNCH_PAGED_ATTENTION_V1(112); LAUNCH_PAGED_ATTENTION_V1(112);
break; break;
case 120:
LAUNCH_PAGED_ATTENTION_V1(120);
break;
case 128: case 128:
LAUNCH_PAGED_ATTENTION_V1(128); LAUNCH_PAGED_ATTENTION_V1(128);
break; break;
@@ -862,7 +865,7 @@ void paged_attention_v2_launcher(
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
@@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
case 112: case 112:
LAUNCH_PAGED_ATTENTION_V2(112); LAUNCH_PAGED_ATTENTION_V2(112);
break; break;
case 120:
LAUNCH_PAGED_ATTENTION_V2(120);
break;
case 128: case 128:
LAUNCH_PAGED_ATTENTION_V2(128); LAUNCH_PAGED_ATTENTION_V2(128);
break; break;

View File

@@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#else #else
return __bfloat1622float2(val); return __bfloat1622float2(val);
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
@@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#else #else
return __bfloat162bfloat162(val); return __bfloat162bfloat162(val);
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
// Vector addition. // Vector addition.
@@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
return __hadd(a, b); return __hadd(a, b);
#endif #endif
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
@@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#else #else
return __hadd2(a, b); return __hadd2(a, b);
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
@@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#else #else
return __hmul(a, b); return __hmul(a, b);
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
template <> template <>
@@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#else #else
return __hmul2(a, b); return __hmul2(a, b);
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
template <> template <>
@@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
#else #else
return __hfma2(a, b, c); return __hfma2(a, b, c);
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
@@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
#else #else
return __hfma2(bf162bf162(a), b, c); return __hfma2(bf162bf162(a), b, c);
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {

View File

@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype); const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);
// Just for unittest // Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,

View File

@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
} }
} }
template <typename scalar_t> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel( __global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
// head_size] // head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
// head_size] // head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride, const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size) { const int num_heads, const int head_size, const int block_size,
const float k_scale, const float v_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded // NOTE: slot_idx can be -1 if the token is padded
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t src_value_idx = token_idx * value_stride + i; const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size; const int head_idx = i / head_size;
const int head_offset = i % head_size; const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride + const int64_t tgt_key_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size + block_offset * num_heads * head_size +
head_idx * head_size + head_offset; head_idx * head_size + head_offset;
k_cache[tgt_value_idx] = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_value_idx] = tgt_key;
value_cache[tgt_key_value_idx] = tgt_value;
} else {
key_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
value_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
}
} }
} }
} // namespace vllm } // namespace vllm
@@ -278,40 +288,45 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE) CALL_RESHAPE_AND_CACHE)
} }
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
void reshape_and_cache_flash( void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) { const std::string& kv_cache_dtype, const double k_scale,
// FIXME: only support auto datatype, does not support fp8 const double v_scale) {
if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
int block_size = k_cache.size(1); int block_size = key_cache.size(1);
int key_stride = key.stride(0); int key_stride = key.stride(0);
int value_stride = value.stride(0); int value_stride = value.stride(0);
int block_stride = k_cache.stride(0); int block_stride = key_cache.stride(0);
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512)); dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_flash", [&] { DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
vllm::reshape_and_cache_flash_kernel<scalar_t> CALL_RESHAPE_AND_CACHE_FLASH);
<<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
value_stride, num_heads, head_size, block_size);
});
} }
namespace vllm { namespace vllm {

382
csrc/core/scalar_type.hpp Normal file
View File

@@ -0,0 +1,382 @@
#pragma once
#include <torch/custom_class.h>
namespace vllm {
//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
// TORCH_LIBRARY, making it accessible from Python as well meaning this class
// can be used as a argument for custom operators, helping to simplify these
// interfaces.
//
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
public:
enum NanRepr : int64_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
NAN_REPR_ID_MAX
};
constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
int64_t bias, bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
bias(bias),
signed_(signed_),
finite_values_only(finite_values_only),
nan_repr(nan_repr){};
static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
return ScalarType(true, 0, size_bits - 1, bias);
}
static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
return ScalarType(false, 0, size_bits, bias);
}
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(int64_t exponent,
int64_t mantissa) {
TORCH_CHECK(mantissa > 0 && exponent > 0);
return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
TORCH_CHECK(mantissa > 0 && exponent > 0);
TORCH_CHECK(nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
return ScalarType(true, exponent, mantissa, 0, finite_values_only,
nan_repr);
}
int64_t const exponent; // size of the exponent field (0 for integer types)
int64_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
int64_t const bias; // stored values equal value + bias,
// used for quantized type
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
int64_t size_bits() const { return mantissa + exponent + is_signed(); }
bool is_signed() const { return signed_; }
bool is_integer() const { return exponent == 0; }
bool is_floating_point() const { return exponent > 0; }
bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false &&
nan_repr == NAN_IEEE_754;
}
bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
bool has_bias() const { return bias != 0; }
private:
double _floating_point_max() const {
TORCH_CHECK(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
max_mantissa -= 1;
}
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
TORCH_CHECK(exponent < 11,
"Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}
// adjust the exponent to match that of a double
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
// is the exponent bits), there is some precedent for non-standard biases,
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
// but to avoid premature over complication we are just assuming the
// standard exponent bias until there is a need to support non-standard
// biases
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
uint64_t max_exponent_double =
max_exponent - exponent_bias + exponent_bias_double;
// shift the mantissa into the position for a double and
// the exponent
uint64_t double_raw =
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
return *reinterpret_cast<double*>(&double_raw);
}
std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
"Cannot represent max as a int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}
std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(),
"We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
TORCH_CHECK(!is_signed() || size_bits() <= 64,
"Cannot represent min as a int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
// (size_bits() - 1) to 1
return {INT64_MIN >> (64 - size_bits())};
} else {
return {int64_t(0)};
}
}
}
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
std::variant<int64_t, double> max() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_max());
}
// Min representable value for this scalar type.
// (accounting for bias if there is one)
std::variant<int64_t, double> min() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_min());
}
std::string str() const {
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f) the scheme is:
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
* flags:
* - no-flags: means it follows IEEE 754 conventions
* - f: means finite values only (no infinities)
* - n: means nans are supported (non-standard encoding)
* for integer types the scheme is:
* `[u]int<size_bits>[b<bias>]`
* - if bias is not present it means its zero
*/
if (is_floating_point()) {
auto ret = "float" + std::to_string(size_bits()) + "_e" +
std::to_string(exponent) + "m" + std::to_string(mantissa);
if (!is_ieee_754()) {
if (finite_values_only) {
ret += "f";
}
if (nan_repr != NAN_NONE) {
ret += "n";
}
}
return ret;
} else {
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
if (has_bias()) {
ret += "b" + std::to_string(bias);
}
return ret;
}
}
bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent &&
bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only &&
nan_repr == other.nan_repr;
}
};
// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
// torch::CustomClassHolder), we use multiple inheritance here since we cannot
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
bool _signed)
: ScalarType(exponent, mantissa, bias, _signed){};
ScalarTypeTorch(ScalarType type) : ScalarType(type){};
using Base = ScalarType;
using Self = ScalarTypeTorch;
using SelfPtr = c10::intrusive_ptr<Self>;
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
return c10::make_intrusive<Self>(
ScalarType::int_(size_bits, bias.value_or(0)));
}
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
return c10::make_intrusive<Self>(
ScalarType::uint(size_bits, bias.value_or(0)));
}
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
return c10::make_intrusive<Self>(
ScalarType::float_IEEE754(exponent, mantissa));
}
static SelfPtr float_(int64_t exponent, int64_t mantissa,
bool finite_values_only, int64_t nan_repr) {
return c10::make_intrusive<Self>(ScalarType::float_(
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
}
template <typename T>
static void bind_readonly_property(torch::class_<Self>& cls,
std::string const& name, T Base::*field) {
auto getter_func = [field = std::move(field)](SelfPtr const& self) {
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
return (self.get()->*field)();
} else {
return self.get()->*field;
}
};
cls.def_property(name, getter_func);
}
template <typename MemberFunc, typename Cls>
static void bind_function(torch::class_<Self>& cls, const std::string& name,
MemberFunc Cls::*member) {
cls.def(name, [member = std::move(member)](SelfPtr const& self) {
return (self.get()->*member)();
});
}
template <typename Func>
static void bind_function(torch::class_<Self>& cls, const std::string& name,
Func func) {
cls.def(name, func);
}
template <typename Func>
static void bind_static_function(torch::class_<Self>& cls,
const std::string& name, Func func) {
cls.def_static(name, func);
}
static void bind_class(torch::Library& lib) {
auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
.def(torch::init<int64_t, int64_t, int64_t, bool>());
// Bind Properties
bind_readonly_property(cls, "mantissa", &Base::mantissa);
bind_readonly_property(cls, "exponent", &Base::exponent);
bind_readonly_property(cls, "bias", &Base::bias);
bind_readonly_property(cls, "signed", &Base::is_signed);
bind_readonly_property(cls, "size_bits", &Base::size_bits);
// Bind member functions
bind_function(cls, "is_signed", &Base::is_signed);
bind_function(cls, "is_integer", &Base::is_integer);
bind_function(cls, "is_floating_point", &Base::is_floating_point);
bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
bind_function(cls, "has_nans", &Base::has_nans);
bind_function(cls, "has_infs", &Base::has_infs);
bind_function(cls, "has_bias", &Base::has_bias);
bind_function(cls, "max", [](SelfPtr const& self) {
return std::visit([](auto arg) { return c10::IValue(arg); },
self.get()->max());
});
bind_function(cls, "min", [](SelfPtr const& self) {
return std::visit([](auto arg) { return c10::IValue(arg); },
self.get()->min());
});
bind_function(cls, "__str__", &Base::str);
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
return *self == *other;
});
bind_function(cls, "__repr__", [](SelfPtr const& self) {
return "ScalarType." + self.get()->str();
});
// Bind static functions (convenience constructors)
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
}
};
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
}; // namespace vllm

View File

@@ -0,0 +1,16 @@
#include <torch/library.h>
#include "scalar_type.hpp"
#include "registration.h"
// Note the CORE exstension will be built for (almost) all hardware targets so
// new additions must account for this. (currently not built for TPU and Neuron)
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
// ScalarType, a custom class for representing data types that supports
// quantized types, declared here so it can be used when creating interfaces
// for custom ops.
vllm::ScalarTypeTorch::bind_class(lib);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@@ -1,9 +1,11 @@
#include "cache.h" #include "cache.h"
#include "ops.h" #include "ops.h"
#include "registration.h" #include "core/registration.h"
#include <torch/library.h> #include <torch/library.h>
void init_cpu_threads_env(const std::string& cpu_ids);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
@@ -107,4 +109,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
} }
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
// CPU utils
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME) REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

65
csrc/cpu/utils.cpp Normal file
View File

@@ -0,0 +1,65 @@
#include <numa.h>
#include <unistd.h>
#include <string>
#include <sched.h>
#include "cpu_types.hpp"
void init_cpu_threads_env(const std::string& cpu_ids) {
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
TORCH_CHECK(omp_cpu_mask->size > 0);
std::vector<int> omp_cpu_ids;
omp_cpu_ids.reserve(omp_cpu_mask->size);
constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
int i = 0;
while (group_mask) {
if (group_mask & 1) {
omp_cpu_ids.emplace_back(offset + i);
}
++i;
group_mask >>= 1;
}
}
// Memory node binding
if (numa_available() != -1) {
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
bitmask* src_mask = numa_get_membind();
int pid = getpid();
// move all existing pages to the specified numa node.
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
int page_num = numa_migrate_pages(pid, src_mask, mask);
if (page_num == -1) {
TORCH_CHECK(false,
"numa_migrate_pages failed. errno: " + std::to_string(errno));
}
// restrict memory allocation node.
numa_set_membind(mask);
numa_set_strict(1);
}
// OMP threads binding
omp_set_num_threads((int)omp_cpu_ids.size());
torch::set_num_threads((int)omp_cpu_ids.size());
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
#pragma omp parallel for schedule(static, 1)
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
CPU_ZERO_S(size, mask);
CPU_SET_S(omp_cpu_ids[i], size, mask);
sched_setaffinity(0, sizeof(cpu_set_t), mask);
CPU_FREE(mask);
}
numa_free_nodemask(omp_cpu_mask);
}

View File

@@ -1,4 +1,4 @@
#include "registration.h" #include "core/registration.h"
#include "moe_ops.h" #include "moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {

View File

@@ -3,6 +3,8 @@
#include <optional> #include <optional>
#include <torch/library.h> #include <torch/library.h>
#include "core/scalar_type.hpp"
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
@@ -84,16 +86,19 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta, torch::Tensor& b_meta,
torch::Tensor& b_scales, torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_m, int64_t size_n,
int64_t size_k); int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp); bool is_k_full, bool has_zp,
bool use_fp32_reduce);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
@@ -114,6 +119,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
torch::Tensor const& s_ch,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

View File

@@ -1,217 +0,0 @@
Contains code from https://github.com/punica-ai/punica
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer
BSD-3-Clause:
* third_party/cutlass

View File

@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

View File

@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)

View File

@@ -1,218 +0,0 @@
#pragma once
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 640) \
f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 896) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \
f(in_T, out_T, W_T, narrow, 1216) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1536) \
f(in_T, out_T, W_T, narrow, 1664) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2240) \
f(in_T, out_T, W_T, narrow, 2304) \
f(in_T, out_T, W_T, narrow, 2368) \
f(in_T, out_T, W_T, narrow, 2432) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 3712) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4480) \
f(in_T, out_T, W_T, narrow, 4608) \
f(in_T, out_T, W_T, narrow, 4736) \
f(in_T, out_T, W_T, narrow, 4864) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 5888) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 7424) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 8960) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 9472) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 11264) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 14784) \
f(in_T, out_T, W_T, narrow, 14848) \
f(in_T, out_T, W_T, narrow, 15360) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 18944) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 22528) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 29568) \
f(in_T, out_T, W_T, narrow, 29696) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 49408) \
f(in_T, out_T, W_T, narrow, 60544) \
f(in_T, out_T, W_T, narrow, 60672) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 896, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1216, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1664, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2240, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2368, narrow) \
f(in_T, out_T, W_T, 2432, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 3712, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4480, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 4736, narrow) \
f(in_T, out_T, W_T, 4864, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 5888, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 7424, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 8960, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 9472, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 11264, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 14784, narrow) \
f(in_T, out_T, W_T, 14848, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 18944, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 22528, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 29568, narrow) \
f(in_T, out_T, W_T, 29696, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 49408, narrow) \
f(in_T, out_T, W_T, 60544, narrow) \
f(in_T, out_T, W_T, 60672, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on

View File

@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)

View File

@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)

View File

@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)

View File

@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)

View File

@@ -1,451 +0,0 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include "vec_dtypes.cuh"
namespace cg = cooperative_groups;
#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
// Does not handle the case of long datatypes
char *d = reinterpret_cast<char *>(dst);
const char *s = reinterpret_cast<const char *>(src);
size_t i = 0;
#pragma unroll
for (i = 0; i < len; ++i) {
d[i] = s[i];
}
return dst;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t j = blockIdx.x;
constexpr size_t num_pipeline_stages = 2;
constexpr size_t tile_size = tx * ty * vec_size;
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
__shared__ float y_warpwise[ty];
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
auto pipe = cuda::make_pipeline();
// pipeline load W/X and compute WX;
pipe.producer_acquire();
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
pipe.producer_commit();
size_t copy_idx, compute_idx;
float y = 0.f;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
++tile_idx) {
copy_idx = tile_idx % num_pipeline_stages;
// pipeline stage: async copy W fragment
pipe.producer_acquire();
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) + tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
}
pipe.producer_commit();
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// pipeline stage: compute WX
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] = sum;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
}
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// final pipeline stage
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] =
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
? sum
: 0.f;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
// write Y;
if (block.thread_rank() == 0) {
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
}
}
#else
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
size_t j = blockIdx.x;
constexpr size_t tile_size = tx * ty * vec_size;
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
__shared__ float y_warpwise[ty];
float y = 0;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
x_vec.load(X + (batch_idx * feat_in) +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
}
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
}
__syncthreads();
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
y += sum;
}
}
if (threadIdx.x == 0) {
y_warpwise[threadIdx.y] = y;
}
__syncthreads();
float y_write = 0.f;
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y_write += y_warpwise[i];
}
// write Y;
if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t y_idx = batch_idx * full_y_size + y_offset + j;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
}
}
#endif
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t tile_idx = blockIdx.x;
// load X;
vec_t<in_T, vec_size> x_vec;
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
// load W;
vec_t<W_T, vec_size> w_vec;
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
block.thread_rank() * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += g.shfl_down(sum, offset);
}
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
#ifndef USE_ROCM
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
}
}
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
constexpr size_t vec_size = 8;
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
constexpr int ty = 32 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
constexpr int ty = 16 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else {
constexpr int ty = 8 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
} else {
#ifndef USE_ROCM
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
if constexpr (feat_in % (vec_size * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
vec_size * sizeof(W_T), tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
constexpr int tx = 16;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
#else
constexpr size_t rocm_warp_size = warpSize;
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
feat_in % (rocm_warp_size * vec_size_) == 0
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
constexpr size_t vec_size_shrink = vec_size_; \
constexpr int tx = tx_; \
constexpr int ty = ty_; \
dim3 nblks(feat_out, batch_size); \
dim3 nthrs(tx, ty); \
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
vec_size_shrink * sizeof(in_T), \
vec_size_shrink * sizeof(W_T), \
tx, ty, tz> \
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
full_y_size, num_layers, layer_idx, \
scale); \
}
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
CHECK_INPUT_TILEABLE_BY(16) ||
CHECK_INPUT_TILEABLE_BY( 8) ||
CHECK_INPUT_TILEABLE_BY( 4) ||
CHECK_INPUT_TILEABLE_BY( 2) ||
CHECK_INPUT_TILEABLE_BY( 1));
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
}
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)

View File

@@ -1,48 +0,0 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
"fp16": "nv_half",
"bf16": "nv_bfloat16",
"fp32": "float",
}
TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501
for input_dtype in DTYPES:
for output_dtype in DTYPES:
for weight_dtype in DTYPES:
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
if output_dtype == "fp32":
# LoRA A matrix.
if input_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# input and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif input_dtype == "fp32":
# LoRA B matrix.
if output_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# output and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif not (input_dtype == output_dtype == weight_dtype):
# NOTE(woosuk): While Punica supports mixed data types for
# input, output, and weight, we only generate the kernels with
# the same data types to reduce the binary size.
continue
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],
weight_dtype=DTYPE_MAP[weight_dtype])
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
with open(filename, "w") as f:
f.write(kernel_definition)

File diff suppressed because it is too large Load Diff

View File

@@ -1,569 +0,0 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h"
//====== utils ======
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
const char *a_name, const char *b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
".size(", i, ")");
}
}
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
return (uint64_t(a) << 32) | uint64_t(b);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) \
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
//====== bgmv ======
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint32_t in_features, uint32_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
// NOTE(woosuk): While Punica supports various combinations of input/output
// data types, we limit the supported data types to reduce the binary size.
constexpr bool is_input_float = std::is_same<in_T, float>::value;
constexpr bool is_output_float = std::is_same<out_T, float>::value;
if (is_input_float) {
if (!std::is_same<out_T, W_T>::value) {
return false;
}
} else if (is_output_float) {
if (!std::is_same<in_T, W_T>::value) {
return false;
}
} else if (!(std::is_same<in_T, W_T>::value &&
std::is_same<out_T, W_T>::value)) {
return false;
}
switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u32(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, double scale) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t h_in = x.size(1);
int64_t h_out = y.size(1);
int64_t num_layers = w.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
double scale, int64_t h_in, int64_t h_out,
int64_t y_offset) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t num_layers = w.size(1);
int64_t full_y_size = y.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

View File

@@ -1,11 +0,0 @@
#pragma once
#include <torch/all.h>
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, double scale);
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
double scale, int64_t h_in, int64_t h_out,
int64_t y_offset);

View File

@@ -1,18 +0,0 @@
#include "registration.h"
#include "punica_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()");
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
m.def(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()");
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@@ -1,82 +0,0 @@
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
typedef __half nv_half;
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
return __hip_bfloat162{val, val};
}
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
return __hip_bfloat162{vall, valr};
}
template <typename T_src, typename T_dst>
__TYPE_CONVERT__HOST_DEVICE__
inline T_dst convert_type(T_src val) {
return static_cast<T_dst>(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__half, float>(__half val) {
return __half2float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half convert_type<float, __half>(float val) {
return __float2half(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
return __bfloat162float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
return __float2bfloat16(val);
}
template <typename T>
__TYPE_CONVERT__HOST_DEVICE__
inline T vllm_add(T a, T b) {
return a + b;
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half vllm_add<__half>(__half a, __half b) {
return __hadd(a, b);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
return __hadd(a, b);
}
#undef __TYPE_CONVERT__HOST_DEVICE__
#endif // USE_ROCM
#endif // CSRC__PUNICA__TYPE_CONVERT_H__

View File

@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
} }
__syncthreads(); __syncthreads();
float res = 0;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1; int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) { while (iters--) {
if (pred && a_gl_rd < a_gl_end) { if (pred && a_gl_rd < a_gl_end) {

View File

@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
return result; return result;
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
} // namespace awq } // namespace awq

View File

@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
namespace vllm { namespace vllm {
namespace awq { namespace awq {
// Pack two half values.
static inline __device__ __host__ unsigned __pack_half2(const half x,
const half y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0;
}
template <int N> template <int N>
__global__ void __launch_bounds__(64) __global__ void __launch_bounds__(64)
gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
__shared__ half A_shared[16 * (32 + 8)]; __shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (N + 8)]; __shared__ half B_shared[32 * (N + 8)];
__shared__ half scaling_factors_shared[N];
__shared__ half zeros_shared[N];
int j_factors1 = ((OC + N - 1) / N); int j_factors1 = ((OC + N - 1) / N);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
static constexpr int row_stride_warp = 32 * 8 / 32; static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / N; static constexpr int row_stride = 2 * 32 * 8 / N;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = bool ld_A_flag =
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
uint32_t B_loaded = uint32_t B_loaded =
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale // - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// q * scale - zero * scale. // q * scale - zero * scale.
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
__global__ void __launch_bounds__(64) __global__ void __launch_bounds__(64)
dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
int* __restrict__ zeros, half* __restrict__ C, int G) { int* __restrict__ zeros, half* __restrict__ C, int G) {
int j_factors1 = 4;
int row_stride2 = 4;
int split_k_iters = 1;
static constexpr uint32_t ZERO = 0x0; static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)]; half B_shared[32 * (128 + 8)];
half* B_shared_ptr2 = B_shared; half* B_shared_ptr2 = B_shared;
half B_shared_warp[32];
int OC = 512;
int N = blockDim.x * gridDim.x; // 2 int N = blockDim.x * gridDim.x; // 2
int col = (blockIdx.x * blockDim.x + threadIdx.x); int col = (blockIdx.x * blockDim.x + threadIdx.x);
int row = blockIdx.y * blockDim.y + threadIdx.y; int row = blockIdx.y * blockDim.y + threadIdx.y;

View File

@@ -64,8 +64,6 @@ using namespace detail;
// Row vector broadcast // Row vector broadcast
template< template<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int Stages, int Stages,
class CtaTileShapeMNK, class CtaTileShapeMNK,
class Element, class Element,
@@ -73,14 +71,12 @@ template<
int Alignment = 128 / sizeof_bits_v<Element> int Alignment = 128 / sizeof_bits_v<Element>
> >
struct Sm90RowOrScalarBroadcast { struct Sm90RowOrScalarBroadcast {
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet"); static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert( static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem struct SharedStorage {
struct SharedStorage { array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
}; };
// This struct has been modified to have a bool indicating that ptr_row is a // This struct has been modified to have a bool indicating that ptr_row is a
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
return args; return args;
} }
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape> template <class ProblemShape>
static size_t static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params), : params(params)
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { } , smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params; Params params;
Element* smem_row; Element *smem = nullptr;
CUTLASS_DEVICE bool CUTLASS_DEVICE bool
is_producer_load_needed() const { is_producer_load_needed() const {
return true; return false;
} }
CUTLASS_DEVICE bool CUTLASS_DEVICE bool
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
return (!params.row_broadcast && *(params.ptr_row) == Element(0)); return (!params.row_broadcast && *(params.ptr_row) == Element(0));
} }
template <int EpiTiles, class GTensor, class STensor>
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
CUTLASS_DEVICE
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
: gRow(cute::forward<GTensor>(gRow)),
sRow(cute::forward<STensor>(sRow)),
params(params) {}
GTensor gRow; // (CTA_M,CTA_N)
STensor sRow; // (CTA_M,CTA_N,PIPE)
Params const& params;
CUTLASS_DEVICE void
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
if (!params.row_broadcast) {
return;
}
if (issue_tma_load) {
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
// Issue the TMA bulk copy
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
// Filter so we don't issue redundant copies over stride-0 modes
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
}
}
};
template <class... Args> template <class... Args>
CUTLASS_DEVICE auto CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) { get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
cute::move(gRow), cute::move(sRow), params);
} }
template <int EpiTiles, class RTensor, class STensor> template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE CUTLASS_DEVICE
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) ConsumerStoreCallbacks(
: tCrRow(cute::forward<RTensor>(tCrRow)), GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
tCsRow(cute::forward<STensor>(tCsRow)), GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
params(params) {} SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, params(params_) {}
RTensor tCrRow; // (CPY,CPY_M,CPY_N) GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
Params const& params; Params const& params;
CUTLASS_DEVICE void CUTLASS_DEVICE void
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { begin() {
if (!params.row_broadcast) { if (!params.row_broadcast) {
fill(tCrRow, *(params.ptr_row)); fill(tSR_rRow, *(params.ptr_row));
return; return;
} }
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM,
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop if (epi_m == 0) { // Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
// (only works if 0-strides are in same location, which is by construction) Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); copy(tSR_sRow_flt, tSR_rRow_flt);
} }
} }
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) { for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tCrRow(epi_v * FragmentSize + i); frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
} }
return frg_row; return frg_row;
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
> >
CUTLASS_DEVICE auto CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) { get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); Tensor sRow = make_tensor(make_smem_ptr(smem),
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
sRow, args.epi_tile, args.tiled_copy, args.thread_idx); //// G2S: Gmem to Smem
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; //// G2S: Coord
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>( auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
cute::move(tCrRow), cute::move(tCsRow), params); Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
params);
} }
}; };
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
return args; return args;
} }
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape> template <class ProblemShape>
static size_t static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
return EmptyProducerLoadCallbacks{}; return EmptyProducerLoadCallbacks{};
} }
template<class GTensor, class RTensor> template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE CUTLASS_DEVICE
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) ConsumerStoreCallbacks(
: tCgCol(cute::forward<GTensor>(tCgCol)), GTensor&& tCgCol,
tCrCol(cute::forward<RTensor>(tCrCol)), RTensor&& tCrCol,
params(params) {} CTensor&& tCcCol,
ProblemShape problem_shape,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params; Params const& params;
int m;
CUTLASS_DEVICE void CUTLASS_DEVICE void
begin() { begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) { if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col)); fill(tCrCol, *(params.ptr_col));
return; return;
@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes // Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction) // (only works if 0-strides are in same location, which is by construction)
copy_aligned(filter(tCgCol), filter(tCrCol)); copy_if(pred, filter(tCgCol), filter(tCrCol));
} }
template <typename ElementAccumulator, int FragmentSize> template <typename ElementAccumulator, int FragmentSize>
@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>( // Generate an identity tensor matching the shape of the global tensor and
cute::move(tCgCol), cute::move(tCrCol), params); // partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
params
);
} }
}; };

View File

@@ -1,470 +1,18 @@
#include <stddef.h> #include <stddef.h>
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "scaled_mm_c2x.cuh"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" #include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "broadcast_load_epilogue_c2x.hpp" #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "common.hpp" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
// clang-format on
using namespace cute;
/* /*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper). NVIDIA GPUs with SM versions prior to sm90 (Hopper).
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/ */
namespace {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
}
};
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
using BiasArgs = typename Bias::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
bias_args};
return evt_compute_args;
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel>;
// clang-format on
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{
evt_args,
d_args,
};
typename Gemm::Op::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
problem_size, // problem size
1, // batch count
epilogue_args,
a_ptr,
b_ptr,
nullptr,
nullptr,
0,
0,
0,
0,
lda,
ldb,
ldc,
ldc};
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static const int max_shared_mem_per_block_opt_in =
get_cuda_max_shared_memory_per_block_opt_in(0);
size_t const gemm_shared_mem_size =
sizeof(typename Gemm::KernelType::SharedStorage);
size_t const fallback_gemm_shared_mem_size =
sizeof(typename FallbackGemm::KernelType::SharedStorage);
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...);
} else {
TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_default {
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
} // namespace
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass2xGemmDefault =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128BigN =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128SmallN =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM64 =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM32 =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM16 =
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using FallbackGemm =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
uint32_t const n = out.size(1);
bool const small_n = n < 8192;
if (small_n) {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else {
// M in (128, inf)
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename> typename Epilogue, template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
@@ -473,20 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm< return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t, Epilogue>(
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm< return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
@@ -501,11 +42,11 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>( return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales, return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
b_scales); out, a, b, a_scales, b_scales);
} }
} }
@@ -518,11 +59,12 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>( return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>( return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
@@ -537,11 +79,11 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>( return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales, return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
b_scales); out, a, b, a_scales, b_scales);
} }
} }
@@ -550,23 +92,17 @@ template <template <typename, typename> typename Epilogue,
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (a.dtype() == torch::kInt8) { if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm< return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t, Epilogue>(
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
assert(out.dtype() == torch::kFloat16); assert(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm< return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t, Epilogue>(
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} else { } else {
@@ -574,17 +110,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller< return vllm::cutlass_gemm_sm89_fp8_dispatch<
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller< return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::half_t, Epilogue>(
cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
@@ -600,10 +132,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>( return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales, return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
b_scales); out, a, b, a_scales, b_scales);
} }
} }

View File

@@ -0,0 +1,340 @@
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
}
};
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
using BiasArgs = typename Bias::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
bias_args};
return evt_compute_args;
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
FP8MathOperator>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel>;
// clang-format on
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{
evt_args,
d_args,
};
typename Gemm::Op::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
problem_size, // problem size
1, // batch count
epilogue_args,
a_ptr,
b_ptr,
nullptr,
nullptr,
0,
0,
0,
0,
lda,
ldb,
ldc,
ldc};
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static const int max_shared_mem_per_block_opt_in =
get_cuda_max_shared_memory_per_block_opt_in(0);
size_t const gemm_shared_mem_size =
sizeof(typename Gemm::KernelType::SharedStorage);
size_t const fallback_gemm_shared_mem_size =
sizeof(typename FallbackGemm::KernelType::SharedStorage);
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...);
} else {
TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm

View File

@@ -0,0 +1,123 @@
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_default {
// This config is used in 2 cases,
// - M in (256, inf]
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M256 {
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M64 {
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M32 {
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass2xGemmDefault =
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM256 =
typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
using Cutlass2xGemmM64 =
typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM32 =
typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm75_config_default has the least shared-memory requirements.
using FallbackGemm = Cutlass2xGemmDefault;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// M in [1, 32]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm

View File

@@ -0,0 +1,139 @@
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_default {
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass2xGemmDefault =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128BigN =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128SmallN =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM64 =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM32 =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM16 =
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using FallbackGemm =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
uint32_t const n = out.size(1);
bool const small_n = n < 8192;
if (small_n) {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else {
// M in (128, inf)
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm

View File

@@ -0,0 +1,368 @@
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5,
FP8MathOperator>;
};
struct sm89_fp8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M128 {
// M in (64, 128]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8196) {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
static const int32_t MainLoopStages = 5;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 24576) {
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm

View File

@@ -0,0 +1,353 @@
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
// Shared mem requirement : 61440
static_assert(std::is_same<InType, int8_t>());
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
static int32_t const MainLoopStages = 5;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
struct sm89_int8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M128 {
// M in (64, 128]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm

View File

@@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp" #include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>; Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleBDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, float>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>; Stride<Int<0>, Int<1>, Int<0>>>;
}; };
/* /*
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
cutlass::multiply_add, ElementD, float, cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>; cutlass::FloatRoundStyle::round_to_nearest>;
using BiasDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, ElementD>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD, 0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>; Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
public: public:
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t ldb = b.stride(1); int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0); int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, Int<0>>; using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, Int<0>>; using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC; using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, Int<0>{}}; StrideA a_stride{lda, Int<1>{}, 0};
StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; StrideB b_stride{ldb, Int<1>{}, 0};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(gemm_op.can_implement(args)); CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args); size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.get(), stream); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
} }

View File

@@ -38,13 +38,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
if (cuda_device_capability >= 90) { if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000; return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) { } else if (cuda_device_capability >= 89) {
// CUTLASS Kernels have not been tuned for Ada Lovelace systems return CUDA_VERSION >= 12040;
// and are slower than torch.mm. Return false unconditionally in this case.
return false;
// Once the CUTLASS kernels have been optimized for Lovelace systems,
// use the following check:
// return CUDA_VERSION >= 12040;
} }
#endif #endif

View File

@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
} }
#endif #endif
assert(false); assert(false);
return {}; // Squash missing return statement warning
} }
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} }
#endif #endif
assert(false); assert(false);
return {}; // Squash missing return statement warning
} }
// The following macro is used to dispatch the conversion function based on // The following macro is used to dispatch the conversion function based on

View File

@@ -48,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input, const scalar_t* __restrict__ input,
int64_t num_elems) { int64_t num_elems) {
__shared__ float cache[1024]; __shared__ float cache[1024];
int i = blockDim.x * blockIdx.x + threadIdx.x; int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
// First store maximum for all values processes by // First store maximum for all values processes by
// the current thread in cache[threadIdx.x] // the current thread in cache[threadIdx.x]

View File

@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
__NV_SATFINITE, fp8_type); __NV_SATFINITE, fp8_type);
return (uint8_t)res; return (uint8_t)res;
#endif #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
// float -> fp8 // float -> fp8
@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
} }
#endif #endif
assert(false); assert(false);
__builtin_unreachable(); // Suppress missing return statement warning
} }
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} }
#endif #endif
assert(false); assert(false);
__builtin_unreachable(); // Suppress missing return statement warning
} }
// The following macro is used to dispatch the conversion function based on // The following macro is used to dispatch the conversion function based on

View File

@@ -21,6 +21,7 @@
#include "marlin.cuh" #include "marlin.cuh"
#include "marlin_dtypes.cuh" #include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \ static_assert(std::is_same<scalar_t, half>::value || \
@@ -59,24 +60,27 @@ __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool use_fp32_reduce // whether to use fp32 global reduce
) {} ) {}
} // namespace gptq_marlin } // namespace marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full) { bool is_k_full, bool has_zp) {
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"); "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1}); return torch::empty({1, 1});
@@ -532,16 +536,18 @@ __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor) // (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool use_fp32_reduce // whether to use fp32 global reduce
) { ) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the // Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 * // same size, which might involve multiple column "slices" (of width 16 *
@@ -595,6 +601,8 @@ __global__ void Marlin(
int slice_idx; // index of threadblock in current slice; numbered bottom to int slice_idx; // index of threadblock in current slice; numbered bottom to
// top // top
int par_id = 0;
// We can easily implement parallel problem execution by just remapping // We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers // indices and advancing global pointers
if (slice_col_par >= n_tiles) { if (slice_col_par >= n_tiles) {
@@ -602,6 +610,7 @@ __global__ void Marlin(
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles; locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles; slice_col = slice_col_par % n_tiles;
par_id = slice_col_par / n_tiles;
} }
// Compute all information about the current slice which is required for // Compute all information about the current slice which is required for
@@ -632,6 +641,7 @@ __global__ void Marlin(
C += 16 * thread_m_blocks * prob_n / 8; C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles; locks += n_tiles;
slice_col = 0; slice_col = 0;
par_id++;
} }
}; };
init_slice(); init_slice();
@@ -1120,44 +1130,53 @@ __global__ void Marlin(
}; };
auto fetch_zp_to_registers = [&](int k, int full_pipe) { auto fetch_zp_to_registers = [&](int k, int full_pipe) {
if constexpr (!has_zp) { // This code does not handle group_blocks == 0,
return; // which signifies act_order.
} // has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);
int pipe = full_pipe % stages; if constexpr (has_zp) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i]; frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
} }
} else if constexpr (group_blocks >= thread_k_blocks) { } else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage = int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks))); (pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
} }
} else { } else {
int warp_id = threadIdx.x / 32; int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16; int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters); cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16; int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks; int cur_group_id = 0;
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; // Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
sh_zp_stage += cur_group_id * zp_sh_stride; int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
for (int i = 0; i < num_ints_per_thread; i++) { sh_zp_stage += cur_group_id * zp_sh_stride;
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} }
} }
}; };
@@ -1321,7 +1340,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped // finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are // partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache. // usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) { auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to // We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out // maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute). // results in FP16 (but still reduce with FP32 compute).
@@ -1382,6 +1401,53 @@ __global__ void Marlin(
} }
}; };
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
constexpr int tb_m = thread_m_blocks * 16;
constexpr int tb_n = thread_n_blocks * 16;
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4;
bool is_th_active = threadIdx.x < active_threads;
int par_offset = c_size * n_tiles * par_id;
int slice_offset = c_size * slice_col;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
constexpr int th_size = num_floats * sizeof(float) / 16;
int c_cur_offset = par_offset + slice_offset;
if (!is_th_active) {
return;
}
if (!first) {
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
sh[threadIdx.x] =
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
#pragma unroll
for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
}
}
}
if (!last) {
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
}
}
};
// Write out the reduce final result in the correct layout. We only actually // Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed // reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout. // in fragment layout.
@@ -1606,7 +1672,11 @@ __global__ void Marlin(
if (slice_count > 1) { // only globally reduce if there is more than one if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice // block in a slice
barrier_acquire(&locks[slice_col], slice_idx); barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last); if (use_fp32_reduce) {
global_reduce_fp32(slice_idx == 0, last);
} else {
global_reduce_fp16(slice_idx == 0, last);
}
barrier_release(&locks[slice_col], last); barrier_release(&locks[slice_col], last);
} }
if (last) // only the last block in a slice actually writes the result if (last) // only the last block in a slice actually writes the result
@@ -1661,8 +1731,8 @@ __global__ void Marlin(
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \ THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \ HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \ A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
prob_m, prob_n, prob_k, locks); \ num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
} }
typedef struct { typedef struct {
@@ -1801,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return true; return true;
} }
int determine_reduce_max_m(int prob_m, int max_par) {
constexpr int tile_m_size = 16;
if (prob_m <= tile_m_size) {
return tile_m_size;
} else if (prob_m <= tile_m_size * 2) {
return tile_m_size * 2;
} else if (prob_m <= tile_m_size * 3) {
return tile_m_size * 3;
} else if (prob_m <= tile_m_size * 4) {
return tile_m_size * 4;
} else {
int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
return tile_m_size * 4 * cur_par;
}
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, int num_bits, int group_size,
bool has_act_order, bool is_k_full, bool has_act_order, bool is_k_full,
@@ -1880,18 +1971,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t> template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* g_idx, void* perm, void* a_tmp, int prob_m, void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace, int num_bits, int prob_n, int prob_k, void* workspace,
bool has_act_order, bool is_k_full, bool has_zp, vllm::ScalarType const& q_type, bool has_act_order,
int num_groups, int group_size, int dev, bool is_k_full, bool has_zp, int num_groups, int group_size,
cudaStream_t stream, int thread_k, int thread_n, int sms, int dev, cudaStream_t stream, int thread_k, int thread_n,
int max_par) { int sms, int max_par, bool use_fp32_reduce) {
TORCH_CHECK(num_bits == 4 || num_bits == 8, if (has_zp) {
"num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
// TODO: remove alias when we start supporting other 8bit types
int num_bits = q_type.size_bits();
int tot_m = prob_m; int tot_m = prob_m;
int tot_m_blocks = div_ceil(tot_m, 16); int tot_m_blocks = div_ceil(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m; int pad = 16 * tot_m_blocks - tot_m;
@@ -1970,6 +2072,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
const int4* A_ptr = (const int4*)A; const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s; const int4* s_ptr = (const int4*)s;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
@@ -2042,18 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
} }
} }
} // namespace gptq_marlin } // namespace marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) { bool is_k_full, bool has_zp,
// Verify num_bits bool use_fp32_reduce) {
TORCH_CHECK(num_bits == 4 || num_bits == 8, if (has_zp) {
"num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
int pack_factor = 32 / num_bits; "b_q_type must be u4 or u8 when has_zp = True. Got = ",
b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type->str());
}
int pack_factor = 32 / b_q_type->size_bits();
// Verify A // Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@@ -2099,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor c = torch::empty({size_m, size_n}, options); torch::Tensor c = torch::empty({size_m, size_n}, options);
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
// Alloc C tmp buffer that is going to be used for the global reduce
int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
int reduce_n = size_n;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (!use_fp32_reduce) {
reduce_max_m = 0;
reduce_n = 0;
}
torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1) // auto -1)
int thread_k = -1; int thread_k = -1;
@@ -2169,22 +2293,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) { if (a.scalar_type() == at::ScalarType::Half) {
marlin::marlin_mm_f16i4<half>( marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(), c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k, a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par); thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else { } else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
} }

View File

@@ -0,0 +1,32 @@
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};

View File

@@ -0,0 +1,89 @@
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}

View File

@@ -25,6 +25,12 @@
#include <iostream> #include <iostream>
#include "common/base.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
#endif
template <typename T> template <typename T>
inline std::string str(T x) { inline std::string str(T x) {
return std::to_string(x); return std::to_string(x);
@@ -32,23 +38,9 @@ inline std::string str(T x) {
namespace marlin_dense { namespace marlin_dense {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
using I4 = Vec<int, 4>; using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is // Matrix fragments for tensor core instructions; their precise layout is
// documented here: // documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
@@ -57,43 +49,6 @@ using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales using FragS = Vec<half2, 1>; // quantization scales
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation. // output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
@@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
template <const int threads, // number of threads in a threadblock template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
@@ -452,10 +374,15 @@ __global__ void Marlin(
B_ptr[i] += b_gl_rd_delta_o; B_ptr[i] += b_gl_rd_delta_o;
} }
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe; // This assumes group_blocks >= thread_k_blocks
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); // and would need to be modified to support smaller groups.
s_gl_rd += s_gl_rd_delta; static_assert(group_blocks >= thread_k_blocks);
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
} }
} }
// Insert a fence even when we are winding down the pipeline to ensure that // Insert a fence even when we are winding down the pipeline to ensure that
@@ -480,7 +407,10 @@ __global__ void Marlin(
// however, this does not seem to be a significant bottleneck, while some // however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by // theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance. // the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) { if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
int4* sh_s_stage = int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks))); (pipe / (group_blocks / thread_k_blocks)));

File diff suppressed because it is too large Load Diff

View File

@@ -27,6 +27,7 @@
#include <iostream> #include <iostream>
#include "common/base.h" #include "common/base.h"
#include "core/scalar_type.hpp"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@@ -86,7 +87,8 @@ __global__ void Marlin_24(
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta, torch::Tensor& b_meta,
torch::Tensor& b_scales, torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_m, int64_t size_n,
int64_t size_k) { int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
@@ -404,10 +406,15 @@ __global__ void Marlin_24(
meta_ptr[i] += m_gl_rd_delta_o; meta_ptr[i] += m_gl_rd_delta_o;
} }
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe; // This assumes group_blocks >= thread_k_blocks
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); // and would need to be modified to support smaller groups.
s_gl_rd += s_gl_rd_delta; static_assert(group_blocks >= thread_k_blocks);
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
} }
} }
// Insert a fence even when we are winding down the pipeline to ensure that // Insert a fence even when we are winding down the pipeline to ensure that
@@ -432,7 +439,10 @@ __global__ void Marlin_24(
// however, this does not seem to be a significant bottleneck, while some // however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by // theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance. // the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) { if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
int4* sh_s_stage = int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks))); (pipe / (group_blocks / thread_k_blocks)));
@@ -1017,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta, torch::Tensor& b_meta,
torch::Tensor& b_scales, torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_m, int64_t size_n,
int64_t size_k) { int64_t size_k) {
// Verify num_bits // Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8, TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"num_bits must be 4 or 8. Got = ", num_bits); "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
int pack_factor = 32 / num_bits; int pack_factor = 32 / b_q_type->size_bits();
// Verify M // Verify M
TORCH_CHECK(size_m == a.size(0), TORCH_CHECK(size_m == a.size(0),
@@ -1118,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
marlin_24::marlin_cuda_2_4( marlin_24::marlin_cuda_2_4(
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, b_q_type->size_bits(), groupsize, dev,
thread_m, sms, max_par); at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
return c; return c;
} }

View File

@@ -197,13 +197,14 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>( vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM #ifndef USE_ROCM
(half2*)vec.data<at::Half>(), (half2*)vec.data_ptr<at::Half>(),
#else #else
(__half2*)vec.data_ptr<at::Half>(), (__half2*)vec.data_ptr<at::Half>(),
#endif #endif
mat.data_ptr<int>(), mat.data_ptr<int>(),
#ifndef USE_ROCM #ifndef USE_ROCM
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(), (half2*)mul.data_ptr<at::Half>(),
(__half*)lookup_table.data_ptr<at::Half>(),
#else #else
(float2*)mul.data_ptr<float>(), (float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(), (__half*)lookup_table.data_ptr<at::Half>(),

View File

@@ -1,7 +1,7 @@
#include "cache.h" #include "cache.h"
#include "cuda_utils.h" #include "cuda_utils.h"
#include "ops.h" #include "ops.h"
#include "registration.h" #include "core/registration.h"
#include <torch/library.h> #include <torch/library.h>
@@ -149,6 +149,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
// marlin_qqq_gemm for QQQ.
ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization. // quantization.
ops.def( ops.def(
@@ -248,7 +252,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache," " Tensor! key_cache,"
" Tensor! value_cache," " Tensor! value_cache,"
" Tensor slot_mapping," " Tensor slot_mapping,"
" str kv_cache_dtype) -> ()"); " str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash); &reshape_and_cache_flash);

View File

@@ -0,0 +1,16 @@
document.addEventListener("DOMContentLoaded", function () {
var script = document.createElement("script");
script.type = "module";
script.id = "runllm-widget-script"
script.src = "https://widget.runllm.com";
script.setAttribute("version", "stable");
script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); // cmd-j or ctrl-j to open the widget.
script.setAttribute("runllm-name", "vLLM");
script.setAttribute("runllm-position", "BOTTOM_RIGHT");
script.setAttribute("runllm-assistant-id", "207");
script.async = true;
document.head.appendChild(script);
});

View File

@@ -5,6 +5,7 @@ vLLM Meetups
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
- `The fifth vLLM meetup <https://lu.ma/lp0gyjqr>`__, with AWS, July 24th 2024. `[Slides] <https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing>`__
- `The fourth vLLM meetup <https://lu.ma/agivllm>`__, with Cloudflare and BentoML, June 11th 2024. `[Slides] <https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing>`__ - `The fourth vLLM meetup <https://lu.ma/agivllm>`__, with Cloudflare and BentoML, June 11th 2024. `[Slides] <https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing>`__
- `The third vLLM meetup <https://robloxandvllmmeetup2024.splashthat.com/>`__, with Roblox, April 2nd 2024. `[Slides] <https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing>`__ - `The third vLLM meetup <https://robloxandvllmmeetup2024.splashthat.com/>`__, with Roblox, April 2nd 2024. `[Slides] <https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing>`__
- `The second vLLM meetup <https://lu.ma/ygxbpzhl>`__, with IBM Research, January 31st 2024. `[Slides] <https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing>`__ `[Video (vLLM Update)] <https://youtu.be/Y0C-DUvEnZQ>`__ `[Video (IBM Research & torch.compile)] <https://youtu.be/m0dMtFLI-dg>`__ - `The second vLLM meetup <https://lu.ma/ygxbpzhl>`__, with IBM Research, January 31st 2024. `[Slides] <https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing>`__ `[Video (vLLM Update)] <https://youtu.be/Y0C-DUvEnZQ>`__ `[Video (IBM Research & torch.compile)] <https://youtu.be/m0dMtFLI-dg>`__

View File

@@ -68,6 +68,8 @@ html_theme_options = {
'use_repository_button': True, 'use_repository_button': True,
'use_edit_page_button': True, 'use_edit_page_button': True,
} }
html_static_path = ["_static"]
html_js_files = ["custom.js"]
# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa # see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa
READTHEDOCS_VERSION_TYPE = os.environ.get('READTHEDOCS_VERSION_TYPE') READTHEDOCS_VERSION_TYPE = os.environ.get('READTHEDOCS_VERSION_TYPE')
@@ -94,6 +96,7 @@ def setup(app):
# Mock out external dependencies here, otherwise the autodoc pages may be blank. # Mock out external dependencies here, otherwise the autodoc pages may be blank.
autodoc_mock_imports = [ autodoc_mock_imports = [
"aiohttp",
"cpuinfo", "cpuinfo",
"torch", "torch",
"transformers", "transformers",
@@ -108,6 +111,7 @@ autodoc_mock_imports = [
"tqdm", "tqdm",
"tensorizer", "tensorizer",
"pynvml", "pynvml",
"outlines",
] ]
for mock_target in autodoc_mock_imports: for mock_target in autodoc_mock_imports:
@@ -141,5 +145,6 @@ intersphinx_mapping = {
} }
autodoc_preserve_defaults = True autodoc_preserve_defaults = True
autodoc_warningiserror = True
navigation_with_keys = False navigation_with_keys = False

View File

@@ -40,8 +40,12 @@ Registry
Base Classes Base Classes
------------ ------------
.. autodata:: vllm.multimodal.NestedTensors
.. autodata:: vllm.multimodal.BatchedTensors .. autodata:: vllm.multimodal.BatchedTensors
.. autodata:: vllm.multimodal.BatchedTensorInputs
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins .. autoclass:: vllm.multimodal.MultiModalDataBuiltins
:members: :members:
:show-inheritance: :show-inheritance:

View File

@@ -107,9 +107,45 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases.
$ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
.. tip::
For example, vLLM v0.5.3 on ROCM 6.1 can be built with the following steps:
.. code-block:: console
$ pip install --upgrade pip
$ # Install PyTorch
$ pip uninstall torch -y
$ pip install --no-cache-dir --pre torch==2.5.0.dev20240726 --index-url https://download.pytorch.org/whl/nightly/rocm6.1
$ # Build & install AMD SMI
$ pip install /opt/rocm/share/amd_smi
$ # Install dependencies
$ pip install --upgrade numba scipy huggingface-hub[cli]
$ pip install "numpy<2"
$ pip install -r requirements-rocm.txt
$ # Apply the patch to ROCM 6.1 (requires root permission)
$ wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib
$ rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so*
$ # Build vLLM for MI210/MI250/MI300.
$ export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
$ python3 setup.py develop
.. tip:: .. tip::
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
- To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of PyTorch, ideally, should match the ROCm driver version. - The ROCm version of PyTorch, ideally, should match the ROCm driver version.
.. tip::
- For MI300x (gfx942) users, to achieve optimal performance, please refer to `MI300x tuning guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html>`_ for performance optimization and tuning tips on system and workflow level.
For vLLM, please refer to `vLLM performance optimization <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization>`_.

View File

@@ -10,6 +10,7 @@ Table of contents:
#. :ref:`Requirements <cpu_backend_requirements>` #. :ref:`Requirements <cpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>` #. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_cpu_backend_from_source>` #. :ref:`Build from source <build_cpu_backend_from_source>`
#. :ref:`Related runtime environment variables <env_intro>`
#. :ref:`Intel Extension for PyTorch <ipex_guidance>` #. :ref:`Intel Extension for PyTorch <ipex_guidance>`
#. :ref:`Performance tips <cpu_backend_performance_tips>` #. :ref:`Performance tips <cpu_backend_performance_tips>`
@@ -47,7 +48,7 @@ Build from source
.. code-block:: console .. code-block:: console
$ sudo apt-get update -y $ sudo apt-get update -y
$ sudo apt-get install -y gcc-12 g++-12 $ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
$ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
- Second, install Python packages for vLLM CPU backend building: - Second, install Python packages for vLLM CPU backend building:
@@ -71,6 +72,15 @@ Build from source
- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building. - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
.. _env_intro:
Related runtime environment variables
-------------------------------------
- ``VLLM_CPU_KVCACHE_SPACE``: specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
- ``VLLM_CPU_OMP_THREADS_BIND``: specify the CPU cores dedicated to the OpenMP threads. For example, ``VLLM_CPU_OMP_THREADS_BIND=0-31`` means there will be 32 OpenMP threads bound on 0-31 CPU cores. ``VLLM_CPU_OMP_THREADS_BIND=0-31|32-63`` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores.
.. _ipex_guidance: .. _ipex_guidance:
Intel Extension for PyTorch Intel Extension for PyTorch
@@ -78,15 +88,11 @@ Intel Extension for PyTorch
- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. - `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed.
.. _cpu_backend_performance_tips: .. _cpu_backend_performance_tips:
Performance tips Performance tips
----------------- -----------------
- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run: - We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run:
.. code-block:: console .. code-block:: console
@@ -96,11 +102,44 @@ Performance tips
$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
$ python examples/offline_inference.py # run vLLM $ python examples/offline_inference.py # run vLLM
- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription. - When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:
- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading. .. code-block:: console
- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores and memory nodes, to avoid the remote memory node access. ``numactl`` is an useful tool for CPU core and memory binding on NUMA platform. Besides, ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful. $ export VLLM_CPU_KVCACHE_SPACE=40
$ export VLLM_CPU_OMP_THREADS_BIND=0-29
$ vllm serve facebook/opt-125m
- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using ``VLLM_CPU_OMP_THREADS_BIND``. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
.. code-block:: console
$ lscpu -e # check the mapping between logical CPU cores and physical CPU cores
# The "CPU" column means the logical CPU core IDs, and the "CORE" column means the physical core IDs. On this platform, two logical cores are sharing one physical core.
CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
# On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15
$ export VLLM_CPU_OMP_THREADS_BIND=0-7
$ python examples/offline_inference.py
- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross NUMA node memory access.

View File

@@ -65,6 +65,10 @@ Here are some common issues that can cause hangs:
If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs. If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs.
Some known issues:
- In ``v0.5.2``, ``v0.5.3``, and ``v0.5.3.post1``, there is a bug caused by `zmq <https://github.com/zeromq/pyzmq/issues/2000>`_ , which can cause hangs at a low probability (once in about 20 times, depending on the machine configuration). The solution is to upgrade to the latest version of ``vllm`` to include the `fix <https://github.com/vllm-project/vllm/pull/6759>`_ .
.. warning:: .. warning::
After you find the root cause and solve the issue, remember to turn off all the debugging environment variables defined above, or simply start a new shell to avoid being affected by the debugging settings. If you don't do this, the system might be slow because many debugging functionalities are turned on. After you find the root cause and solve the issue, remember to turn off all the debugging environment variables defined above, or simply start a new shell to avoid being affected by the debugging settings. If you don't do this, the system might be slow because many debugging functionalities are turned on.

View File

@@ -9,7 +9,7 @@ Requirements
------------ ------------
* OS: Linux * OS: Linux
* Python: 3.8 -- 3.11 * Python: 3.8 -- 3.12
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.) * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.)
Install with pip Install with pip
@@ -48,7 +48,7 @@ You can install vLLM using pip:
.. code-block:: console .. code-block:: console
$ export VLLM_VERSION=0.5.2 # vLLM's main branch version is currently set to latest released tag $ export VLLM_VERSION=0.5.4 # vLLM's main branch version is currently set to latest released tag
$ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl
$ # You can also access a specific commit $ # You can also access a specific commit
$ # export VLLM_COMMIT=... $ # export VLLM_COMMIT=...
@@ -66,7 +66,6 @@ You can also build and install vLLM from source:
$ git clone https://github.com/vllm-project/vllm.git $ git clone https://github.com/vllm-project/vllm.git
$ cd vllm $ cd vllm
$ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability
$ pip install -e . # This may take 5-10 minutes. $ pip install -e . # This may take 5-10 minutes.
.. tip:: .. tip::

View File

@@ -131,6 +131,6 @@ Once neuronx-cc and transformers-neuronx packages are installed, we will be able
$ git clone https://github.com/vllm-project/vllm.git $ git clone https://github.com/vllm-project/vllm.git
$ cd vllm $ cd vllm
$ pip install -U -r requirements-neuron.txt $ pip install -U -r requirements-neuron.txt
$ pip install . $ VLLM_TARGET_DEVICE="neuron" pip install .
If neuron packages are detected correctly in the installation process, ``vllm-0.3.0+neuron212`` will be installed. If neuron packages are detected correctly in the installation process, ``vllm-0.3.0+neuron212`` will be installed.

View File

@@ -57,7 +57,7 @@ Install from source
.. code-block:: console .. code-block:: console
$ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v . $ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
.. _openvino_backend_performance_tips: .. _openvino_backend_performance_tips:

View File

@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y $ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA. $ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240713" $ export DATE="+20240726"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
@@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop $ VLLM_TARGET_DEVICE="tpu" python setup.py develop
.. note::
Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
The compilation time may take 20~30 minutes in the first run.
However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default).
.. tip:: .. tip::
If you encounter the following error: If you encounter the following error:

View File

@@ -105,6 +105,7 @@ Documentation
quantization/supported_hardware quantization/supported_hardware
quantization/auto_awq quantization/auto_awq
quantization/bnb
quantization/fp8 quantization/fp8
quantization/fp8_e5m2_kvcache quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache quantization/fp8_e4m3_kvcache
@@ -116,6 +117,12 @@ Documentation
automatic_prefix_caching/apc automatic_prefix_caching/apc
automatic_prefix_caching/details automatic_prefix_caching/details
.. toctree::
:maxdepth: 1
:caption: Performance benchmarks
performance_benchmark/benchmarks
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
:caption: Developer Documentation :caption: Developer Documentation

View File

@@ -7,6 +7,8 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo
The following is the list of model architectures that are currently supported by vLLM. The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it. Alongside each architecture, we include some popular models that use it.
----
Decoder-only Language Models Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table:: .. list-table::
@@ -113,6 +115,10 @@ Decoder-only Language Models
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
- -
* - :code:`NemotronForCausalLM`
- Nemotron-3, Nemotron-4, Minitron
- :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc.
- ✅︎
* - :code:`OLMoForCausalLM` * - :code:`OLMoForCausalLM`
- OLMo - OLMo
- :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
@@ -182,6 +188,10 @@ Vision Language Models
- Models - Models
- Example HuggingFace Models - Example HuggingFace Models
- :ref:`LoRA <lora>` - :ref:`LoRA <lora>`
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
-
* - :code:`ChameleonForConditionalGeneration` * - :code:`ChameleonForConditionalGeneration`
- Chameleon - Chameleon
- :code:`facebook/chameleon-7b` etc. - :code:`facebook/chameleon-7b` etc.
@@ -190,6 +200,10 @@ Vision Language Models
- Fuyu - Fuyu
- :code:`adept/fuyu-8b` etc. - :code:`adept/fuyu-8b` etc.
- -
* - :code:`InternVLChatModel`
- InternVL2
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
* - :code:`LlavaForConditionalGeneration` * - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5 - LLaVA-1.5
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
@@ -206,6 +220,16 @@ Vision Language Models
- Phi-3-Vision - Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc. - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
- -
* - :code:`MiniCPMV`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-
.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>` Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`

View File

@@ -73,7 +73,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
print(generated_text) print(generated_text)
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_. A code example can be found in `examples/offline_inference_vision_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py>`_.
Online OpenAI Vision API Compatible Inference Online OpenAI Vision API Compatible Inference

View File

@@ -0,0 +1,23 @@
.. _benchmarks:
Benchmark suites of vLLM
========================
vLLM contains two sets of benchmarks:
+ **Performance benchmarks**: benchmark vLLM's performance under various workloads at a high frequency (when a pull request (PR for short) of vLLM is being merged). See `vLLM performance dashboard <https://perf.vllm.ai>`_ for the latest performance results.
+ **Nightly benchmarks**: compare vLLM's performance against alternatives (tgi, trt-llm, and lmdeploy) when there are major updates of vLLM (e.g., bumping up to a new version). The latest results are available in the `vLLM GitHub README <https://github.com/vllm-project/vllm/blob/main/README.md>`_.
Trigger a benchmark
-------------------
The performance benchmarks and nightly benchmarks can be triggered by submitting a PR to vLLM, and label the PR with `perf-benchmarks` and `nightly-benchmarks`.
.. note::
Please refer to `vLLM performance benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/tests/descriptions.md>`_ and `vLLM nightly benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/nightly-descriptions.md>`_ for detailed descriptions on benchmark environment, workload and metrics.

View File

@@ -0,0 +1,43 @@
.. _bits_and_bytes:
BitsAndBytes
==================
vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data.
Below are the steps to utilize BitsAndBytes with vLLM.
.. code-block:: console
$ pip install bitsandbytes>=0.42.0
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint.
You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
And usually, these repositories have a config.json file that includes a quantization_config section.
Read quantized checkpoint.
--------------------------
.. code-block:: python
from vllm import LLM
import torch
# unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
model_id = "unsloth/tinyllama-bnb-4bit"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")
Inflight quantization: load as 4bit quantization
------------------------------------------------
.. code-block:: python
from vllm import LLM
import torch
model_id = "huggyllama/llama-7b"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")

Some files were not shown because too many files have changed in this diff Show More