Compare commits

..

382 Commits

Author SHA1 Message Date
youkaichao
8f89d72090 [Doc] add common case for long waiting time (#5430)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.10, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.11, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.8, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.9, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.10, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.11, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.8, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.9, 2.3.0) (push) Has been cancelled
2024-06-11 11:12:13 -07:00
Nick Hill
99dac099ab [Core][Doc] Default to multiprocessing for single-node distributed case (#5230)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
2024-06-11 11:10:41 -07:00
youkaichao
c4bd03c7c5 [Core][Distributed] add same-node detection (#5369) 2024-06-11 10:53:59 -07:00
sasha0552
dcbf4286af [Frontend] Customizable RoPE theta (#5197) 2024-06-11 10:42:26 -07:00
Ali Panahi
00e6a2dc53 [Bugfix] fix lora_dtype value type in arg_utils.py (#5398) 2024-06-11 10:40:23 -07:00
Junichi Sato
2e02311a1b [Bugfix] Fix MultiprocessingGPUExecutor.check_health when world_size == 1 (#5254) 2024-06-11 10:38:07 -07:00
Cade Daniel
89ec06c33b [Docs] [Spec decode] Fix docs error in code example (#5427) 2024-06-11 10:31:56 -07:00
Kuntai Du
9fde251bf0 [Doc] Add an automatic prefix caching section in vllm documentation (#5324)
Co-authored-by: simon-mo <simon.mo@hey.com>
2024-06-11 10:24:59 -07:00
Cade Daniel
4c2ffb28ff [Speculative decoding] Initial spec decode docs (#5400) 2024-06-11 10:15:40 -07:00
SangBin Cho
246598a6b1 [CI] docfix (#5410)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: ywang96 <ywang@roblox.com>
2024-06-11 01:28:50 -07:00
Woosuk Kwon
8bab4959be [Misc] Remove VLLM_BUILD_WITH_NEURON env variable (#5389) 2024-06-11 00:37:56 -07:00
Roger Wang
3c4cebf751 [Doc][Typo] Fixing Missing Comma (#5403) 2024-06-11 00:20:28 -07:00
youkaichao
d8f31f2f8b [Doc] add debugging tips (#5409) 2024-06-10 23:21:43 -07:00
Cyrus Leung
640052b069 [Bugfix][Frontend] Cleanup "fix chat logprobs" (#5026) 2024-06-10 22:36:46 -07:00
maor-ps
351d5e7b82 [Bugfix] OpenAI entrypoint limits logprobs while ignoring server defined --max-logprobs (#5312)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
2024-06-11 10:30:31 +08:00
Nick Hill
a008629807 [Misc] Various simplifications and typing fixes (#5368) 2024-06-11 10:29:02 +08:00
Kevin H. Luu
76477a93b7 [ci] Fix Buildkite agent path (#5392)
Signed-off-by: kevin <kevin@anyscale.com>
2024-06-10 18:58:07 -07:00
Michael Goin
77c87beb06 [Doc] Add documentation for FP8 W8A8 (#5388) 2024-06-10 18:55:12 -06:00
Simon Mo
114332b88e Bump version to v0.5.0 (#5384) 2024-06-10 15:56:06 -07:00
Woosuk Kwon
cb77ad836f [Docs] Alphabetically sort sponsors (#5386) 2024-06-10 15:17:19 -05:00
Roger Wang
856c990041 [Docs] Add Docs on Limitations of VLM Support (#5383) 2024-06-10 09:53:50 -07:00
Kevin H. Luu
c5602f0baa [ci] Mount buildkite agent on Docker container to upload benchmark results (#5330)
Signed-off-by: kevin <kevin@anyscale.com>
2024-06-10 09:22:34 -07:00
Kevin H. Luu
f7f9c5f97b [ci] Use small_cpu_queue for doc build (#5331)
Signed-off-by: kevin <kevin@anyscale.com>
2024-06-10 09:21:11 -07:00
Cyrus Leung
2c0d933594 [Bugfix] Fix LLaVA-NeXT (#5380) 2024-06-10 15:38:47 +00:00
Itay Etelis
774d1035e4 [Feature][Frontend]: Continued stream_options implementation also in CompletionRequest (#5319) 2024-06-10 14:22:09 +00:00
Cyrus Leung
6b29d6fe70 [Model] Initial support for LLaVA-NeXT (#4199)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-06-10 12:47:15 +00:00
Cyrus Leung
0bfa1c4f13 [Misc] Improve error message when LoRA parsing fails (#5194) 2024-06-10 19:38:49 +08:00
youkaichao
c81da5f56d [misc][typo] fix typo (#5372) 2024-06-10 09:51:02 +00:00
Roger Wang
68bc81703e [Frontend][Misc] Enforce Pixel Values as Input Type for VLMs in API Server (#5374) 2024-06-10 09:13:39 +00:00
Dipika Sikka
5884c2b454 [Misc] Update to comply with the new compressed-tensors config (#5350)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-06-10 03:49:46 +00:00
Bla_ckB
45f92c00cf [Bugfix] Fix KeyError: 1 When Using LoRA adapters (#5164) 2024-06-09 16:23:14 -07:00
bnellnm
5467ac3196 [Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047) 2024-06-09 16:23:30 -04:00
youkaichao
5d7e3d0176 [mis][ci/test] fix flaky test in test_sharded_state_loader.py (#5361)
[mis][ci/test] fix flaky test in tests/test_sharded_state_loader.py (#5361)
2024-06-09 03:50:14 +00:00
youkaichao
0373e1837e [Core][CUDA Graph] add output buffer for cudagraph (#5074)
[Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (#5074)
2024-06-08 19:14:43 -07:00
Michael Goin
c09dade2a2 [Misc][Breaking] Change FP8 checkpoint format from act_scale -> input_scale (#5353) 2024-06-08 13:54:05 -04:00
youkaichao
8ea5e44a43 [CI/Test] improve robustness of test (vllm_runner) (#5357)
[CI/Test] improve robustness of test by replacing del with context manager (vllm_runner) (#5357)
2024-06-08 08:59:20 +00:00
youkaichao
9fb900f90c [CI/Test] improve robustness of test (hf_runner) (#5347)
[CI/Test] improve robustness of test by replacing del with context manager (hf_runner) (#5347)
2024-06-07 22:31:32 -07:00
Hongxia Yang
c96fc06747 [ROCm][AMD] Use pytorch sdpa math backend to do naive attention (#4965) 2024-06-07 19:13:12 -07:00
Benjamin Kitor
b3376e5c76 [Misc] Add args for selecting distributed executor to benchmarks (#5335) 2024-06-08 09:20:16 +08:00
Cheng Li
e69ded7d1c [Bug Fix] Fix the support check for FP8 CUTLASS (#5352)
Bug description:
With torch 2.4.0.dev20240603+cu121,
cutlass_fp8_supported outputs False, and the (capability, version) before the comparison is (90, 11111111112)

This PR fixes the support check for FP8 CUTLASS ( cutlass_fp8_supported) which was introduced in https://github.com/vllm-project/vllm/pull/5183.
2024-06-08 00:42:05 +00:00
Calvinn Ng
767c727a81 fix DbrxFusedNormAttention missing cache_config (#5340)
Co-authored-by: team <calvinn.ng@ahrefs.com>
2024-06-07 14:10:21 -07:00
Jie Fu (傅杰)
6840a71610 [Misc] Remove unused cuda_utils.h in CPU backend (#5345) 2024-06-07 14:09:13 -07:00
Roger Wang
7a9cb294ae [Frontend] Add OpenAI Vision API Support (#5237)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
2024-06-07 11:23:32 -07:00
Dipika Sikka
ca3ea51bde [Kernel] Dynamic Per-Token Activation Quantization (#5037)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2024-06-07 09:36:26 -07:00
limingshu
dc49fb892c Addition of lacked ignored_seq_groups in _schedule_chunked_prefill (#5296) 2024-06-07 13:35:42 +00:00
Antoni Baum
18a277b52d Remove Ray health check (#4693) 2024-06-07 10:01:56 +00:00
Tyler Michael Smith
8d75fe48ca [Kernel] Switch fp8 layers to use the CUTLASS kernels (#5183)
Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and #5144 for comparisons across different GEMM sizes.
2024-06-07 08:42:35 +00:00
youkaichao
388596c914 [Misc][Utils] allow get_open_port to be called for multiple times (#5333) 2024-06-06 22:15:11 -07:00
Itay Etelis
baa15a9ec3 [Feature][Frontend]: Add support for stream_options in ChatCompletionRequest (#5135) 2024-06-07 03:29:24 +00:00
Jie Fu (傅杰)
15063741e3 [Misc] Missing error message for custom ops import (#5282) 2024-06-06 20:17:21 -07:00
Antoni Baum
ccdc490dda [Core] Change LoRA embedding sharding to support loading methods (#5038) 2024-06-06 19:07:57 -07:00
Antoni Baum
a31cab7556 [Core] Avoid copying prompt/output tokens if no penalties are used (#5289) 2024-06-06 18:12:00 -07:00
Matthew Goldey
828da0d44e [Frontend] enable passing multiple LoRA adapters at once to generate() (#5300) 2024-06-06 15:48:13 -05:00
Philipp Moritz
abe855d637 [Kernel] Retune Mixtral 8x22b configs for FP8 on H100 (#5294) 2024-06-06 09:29:29 -07:00
liuyhwangyh
4efff036f0 Bugfix: fix broken of download models from modelscope (#5233)
Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
2024-06-06 09:28:10 -07:00
Cyrus Leung
89c920785f [CI/Build] Update vision tests (#5307) 2024-06-06 05:17:18 -05:00
Breno Faria
7b0a0dfb22 [Frontend][Core] Update Outlines Integration from FSM to Guide (#4109)
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
2024-06-05 16:49:12 -07:00
Simon Mo
3a6ae1d33c [CI] Disable flash_attn backend for spec decode (#5286) 2024-06-05 15:49:27 -07:00
Simon Mo
8f1729b829 [Docs] Add Ray Summit CFP (#5295) 2024-06-05 15:25:18 -07:00
Woosuk Kwon
6a7c7711a2 [Misc] Skip for logits_scale == 1.0 (#5291) 2024-06-05 15:19:02 -07:00
Alex Wu
0f83ddd4d7 [Bugfix][Frontend/Core] Don't log exception when AsyncLLMEngine gracefully shuts down. (#5290) 2024-06-05 15:18:12 -07:00
Michael Goin
065aff6c16 [Bugfix] Make EngineArgs use named arguments for config construction (#5285) 2024-06-05 15:16:56 -07:00
Nick Hill
3d33e372a1 [BugFix] Fix log message about default max model length (#5284) 2024-06-05 14:53:16 -07:00
Nick Hill
faf71bcd4b [Speculative Decoding] Add ProposerWorkerBase abstract class (#5252) 2024-06-05 14:53:05 -07:00
Simon Mo
f270a39537 [Docs] Add Sequoia as sponsors (#5287) 2024-06-05 18:02:56 +00:00
Philipp Moritz
51a08e7d8f [Kernel] Re-tune Mixtral MoE configurations for FP8 on H100 (#5238) 2024-06-05 10:59:14 -07:00
DriverSong
eb8fcd2666 [BugFix] Apply get_cached_tokenizer to the tokenizer setter of LLM (#5207)
Co-authored-by: qiujiawei9 <qiujiawei9@jd.com>
2024-06-05 10:59:02 -07:00
Cody Yu
5563a4dea8 [Model] Correct Mixtral FP8 checkpoint loading (#5231) 2024-06-05 10:58:50 -07:00
Tyler Michael Smith
ccd4f129e8 [Kernel] Add GPU architecture guards to the CUTLASS w8a8 kernels to reduce binary size (#5157)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-06-05 10:44:15 -07:00
Tyler Michael Smith
02cc3b51a7 [misc] benchmark_serving.py -- add ITL results and tweak TPOT results (#5263) 2024-06-05 10:17:51 -07:00
Simon Mo
d5b1eb081e [CI] Add nightly benchmarks (#5260) 2024-06-05 09:42:08 -07:00
tomeras91
f0a500545f [Frontend] OpenAI API server: Add add_special_tokens to ChatCompletionRequest (default False) (#5278) 2024-06-05 09:32:58 -07:00
Woosuk Kwon
c65146e75e [Misc] Fix docstring of get_attn_backend (#5271) 2024-06-05 09:18:59 -07:00
Woosuk Kwon
41ca62cf03 [Misc] Add CustomOp interface for device portability (#5255) 2024-06-05 09:18:19 -07:00
zifeitong
974fc9b845 [Bugfix] Fix prompt_logprobs when SamplingParams.detokenize is set to True (#5226) 2024-06-04 19:37:28 -07:00
youkaichao
fee4dcc33a [Misc] update collect env (#5261) 2024-06-04 17:29:09 -05:00
Michael Goin
650a4cc55e [Misc] Add transformers version to collect_env.py (#5259) 2024-06-04 12:52:28 -07:00
Simon Mo
9ca62d8668 [CI] mark AMD test as softfail to prevent blockage (#5256) 2024-06-04 11:34:53 -07:00
Li, Jiang
45c35f0d58 [CI/Build] Reducing CPU CI execution time (#5241) 2024-06-04 10:26:40 -07:00
Cyrus Leung
9ba093b4f4 [CI/Build] Simplify model loading for HfRunner (#5251) 2024-06-04 10:09:19 -07:00
Woosuk Kwon
27208be66e [Kernel] Add back batch size 1536 and 3072 to MoE tuning (#5242) 2024-06-04 09:58:47 -07:00
Jie Fu (傅杰)
87d5abef75 [Bugfix] Fix a bug caused by pip install setuptools>=49.4.0 for CPU backend (#5249) 2024-06-04 09:57:51 -07:00
Cyrus Leung
ec784b2526 [CI/Build] Add inputs tests (#5215) 2024-06-03 21:01:46 -07:00
zifeitong
a58f24e590 [Bugfix] Fix torch.compile() error when using MultiprocessingGPUExecutor (#5229) 2024-06-03 20:55:50 -07:00
afeldman-nm
f42a006b15 [Bugfix]: During testing, use pytest monkeypatch for safely overriding the env var that indicates the vLLM backend (#5210) 2024-06-03 20:32:57 -07:00
Woosuk Kwon
3a434b07ed [Kernel] Enhance MoE benchmarking & tuning script (#4921) 2024-06-03 20:06:59 -07:00
Zhuohan Li
bd0e7802e0 [Bugfix] Add warmup for prefix caching example (#5235) 2024-06-03 19:36:41 -07:00
Toshiki Kataoka
06b2550cbb [Bugfix] Support prompt_logprobs==0 (#5217) 2024-06-03 17:59:30 -07:00
Breno Faria
f775a07e30 [FRONTEND] OpenAI tools support named functions (#5032) 2024-06-03 18:25:29 -05:00
Kevin H. Luu
4f0d17c05c New CI template on AWS stack (#5110)
Signed-off-by: kevin <kevin@anyscale.com>
2024-06-03 16:16:43 -07:00
Kaiyang Chen
10c38e3e46 [Misc]: Implement CPU/GPU swapping in BlockManagerV2 (#3834) 2024-06-03 13:37:11 -07:00
Yuan
cafb8e06c5 [CI/BUILD] enable intel queue for longer CPU tests (#4113) 2024-06-03 10:39:50 -07:00
Tyler Michael Smith
cbb2f59cc8 [Kernel] Pass a device pointer into the quantize kernel for the scales (#5159) 2024-06-03 09:52:30 -07:00
Antoni Baum
0ab278ca31 [Core] Remove unnecessary copies in flash attn backend (#5138) 2024-06-03 09:39:31 -07:00
Cyrus Leung
7a64d24aad [Core] Support image processor (#4197) 2024-06-02 22:56:41 -07:00
Cyrus Leung
dfbe60dc62 [Misc] Simplify code and fix type annotations in conftest.py (#5118) 2024-06-02 16:05:50 -07:00
Divakar Verma
a66cf40b20 [Kernel][ROCm][AMD] enable fused topk_softmax kernel for moe layer (#4927)
This PR enables the fused topk_softmax kernel used in moe layer for HIP
2024-06-02 14:13:26 -07:00
Avinash Raj
f790ad3c50 [Frontend][OpenAI] Support for returning max_model_len on /v1/models response (#4643) 2024-06-02 08:06:13 +00:00
Simon Mo
ed59a7ed23 Update test_ignore_eos (#4898) 2024-06-02 02:21:53 +00:00
Robert Shaw
044793d8df [BugFix] Prevent LLM.encode for non-generation Models (#5184)
Co-authored-by: mgoin <michael@neuralmagic.com>
2024-06-01 23:35:41 +00:00
Daniil Arapov
c2d6d2f960 [Bugfix]: Fix issues related to prefix caching example (#5177) (#5180) 2024-06-01 15:53:52 -07:00
Zhuohan Li
8279078e21 [Bugfix] Remove deprecated @abstractproperty (#5174) 2024-06-01 22:40:25 +00:00
chenqianfzh
b9c0605a8e [Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776) 2024-06-01 14:51:10 -06:00
Nadav Shmayovits
37464a0f74 [Bugfix] Fix call to init_logger in openai server (#4765) 2024-06-01 17:18:50 +00:00
Ye Cao
c354072828 [Minor] Fix the path typo in loader.py: save_sharded_states.py -> save_sharded_state.py (#5151)
Signed-off-by: Ye Cao <caoye.cao@alibaba-inc.com>
2024-06-01 17:11:22 +00:00
Varun Sundar Rabindranath
f081c3ce4b [Kernel] Update Cutlass fp8 configs (#5144)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-06-01 08:46:07 +00:00
Tyler Michael Smith
260d119e86 [Kernel] Refactor CUTLASS kernels to always take scales that reside on the GPU (#5137) 2024-06-01 06:45:32 +00:00
Daniele
a360ff80bb [CI/Build] CMakeLists: build all extensions' cmake targets at the same time (#5034) 2024-05-31 22:06:45 -06:00
Tyler Michael Smith
1197e02141 [Build] Guard against older CUDA versions when building CUTLASS 3.x kernels (#5168)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.10, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.11, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.8, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.9, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.10, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.11, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.8, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.9, 2.3.0) (push) Has been cancelled
2024-05-31 17:21:38 -07:00
Nick Hill
657579113f [Doc] Add checkmark for GPTBigCodeForCausalLM LoRA support (#5171) 2024-05-31 17:20:19 -07:00
Cody Yu
e9899fb7a4 [Model] Enable FP8 QKV in MoE and refine kernel tuning script (#5039) 2024-05-31 14:29:19 -07:00
functionxu123
a377f0bd5e [Misc]: optimize eager mode host time (#4196)
Co-authored-by: xuhao <xuhao@cambricon.com>
2024-05-31 13:14:50 +08:00
Simon Mo
e9d3aa04f6 Revert "[Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5)" (#5149) 2024-05-30 22:00:26 -07:00
SnowDist
a22dea54d3 [Model] Support MAP-NEO model (#5081)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2024-05-30 19:24:41 -07:00
simon-mo
533c217792 Fix cutlass sm_90a vesrion in CMakeList 2024-05-31 02:13:01 +00:00
Alexander Matveev
6d21fa1cad [Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5) (#5136) 2024-05-30 21:02:11 -05:00
Robert Shaw
b35be5403f [Bugfix] Avoid Warnings in SparseML Activation Quantization (#5120) 2024-05-30 17:04:37 -07:00
Simon Mo
45a1a69b98 [Build] Disable sm_90a in cu11 (#5141) 2024-05-30 14:37:16 -07:00
Simon Mo
87a658c812 Bump version to v0.4.3 (#5046) 2024-05-30 11:13:46 -07:00
Chansung Park
429d89720e add doc about serving option on dstack (#3074)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-05-30 10:11:07 -07:00
Cyrus Leung
a9bcc7afb2 [Doc] Use intersphinx and update entrypoints docs (#5125) 2024-05-30 09:59:23 -07:00
Hyunsung Lee
d79d9eaaff [Misc] remove duplicate definition of seq_lens_tensor in model_runner.py (#5129) 2024-05-30 06:56:19 -07:00
youkaichao
f758505c73 [CI/Build] increase wheel size limit to 200 MB (#5130) 2024-05-30 06:29:48 -07:00
Robert Shaw
d910816c73 [Bugfix] Automatically Detect SparseML models (#5119) 2024-05-30 12:58:37 +00:00
Breno Faria
87d41c849d [BUGFIX] [FRONTEND] Correct chat logprobs (#5029)
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
2024-05-30 02:52:14 -07:00
omkar kakarparthi
e07aff9e52 [CI/Build] Docker cleanup functionality for amd servers (#5112)
Co-authored-by: Alexey Kondratiev <alexey.kondratiev@amd.com>
Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Co-authored-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
Co-authored-by: omkarkakarparthi <okakarpa>
2024-05-30 03:27:39 +00:00
Alexander Matveev
5bf185a1c4 [Bugfix] gptq_marlin: Ensure g_idx_sort_indices is not a Parameter (#5108) 2024-05-30 00:30:18 +00:00
youkaichao
4fbcb0f27e [Doc][Build] update after removing vllm-nccl (#5103)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
2024-05-29 23:51:18 +00:00
Itay Etelis
7c3604fb68 [Bugfix] logprobs is not compatible with the OpenAI spec #4795 (#5031) 2024-05-29 16:13:22 -07:00
Cyrus Leung
b1c255630d [Core] Avoid the need to pass None values to Sequence.inputs (#5099) 2024-05-29 16:05:01 -07:00
Cyrus Leung
eb6c50cdc2 [Bugfix][CI/Build] Fix codespell failing to skip files in git diff (#5097) 2024-05-29 16:02:54 -07:00
Cyrus Leung
eecd864388 [Bugfix][CI/Build] Fix test and improve code for merge_async_iterators (#5096) 2024-05-29 16:02:25 -07:00
Ronen Schaffer
ae495c74ea [Doc]Replace deprecated flag in readme (#4526) 2024-05-29 22:26:33 +00:00
afeldman-nm
4238bc82f2 [Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (#4837) 2024-05-29 16:09:13 +00:00
youkaichao
594392d27a [Core][Distributed] improve p2p access check (#4992) 2024-05-29 11:29:07 +00:00
Cyrus Leung
18c1f16d86 [Bugfix] Fix arguments passed to Sequence in stop checker test (#5092) 2024-05-29 07:16:41 +00:00
youkaichao
5bd3c65072 [Core][Optimization] remove vllm-nccl (#5091) 2024-05-29 05:13:52 +00:00
Marut Pandya
616e600e0b [Misc] add gpu_memory_utilization arg (#5079)
Signed-off-by: pandyamarut <pandyamarut@gmail.com>
2024-05-28 17:16:18 -07:00
Junichi Sato
dfba529b40 [Bugfix] Remove the last EOS token unless explicitly specified (#5077) 2024-05-28 17:15:35 -07:00
Cyrus Leung
5ae5ed1e60 [Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-05-28 13:29:31 -07:00
Simon Mo
290f4ada2b [Docs] Add Dropbox as sponsors (#5089) 2024-05-28 10:29:09 -07:00
Divakar Verma
dd8de11f0a [Kernel][ROCm][AMD] Add fused_moe Triton configs for MI300X (#4951)
This PR adds Triton kernel configs for the MoE kernel for MI300X
2024-05-28 16:03:23 +00:00
Robert Shaw
9ba415588a [BugFix] Fix Embedding Models with TP>1 (#5075) 2024-05-28 08:32:42 -07:00
Michał Moskal
d4f3985907 [Core] Sliding window for block manager v2 (#4545)
Co-authored-by: Ruth Evans <ruthevans@Ruths-MacBook-Pro.local>
2024-05-28 11:07:07 +09:00
Isotr0py
890aa93d27 [Model] Add support for falcon-11B (#5069) 2024-05-27 16:41:43 -07:00
sasha0552
fbdb7b3ee2 [Core] Allow AQLM on Pascal (#5058) 2024-05-27 15:26:14 -07:00
Zhuohan Li
1102bef219 [Bugfix / Core] Prefix Caching Guards (merged with main) (#4846)
Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-05-27 15:18:17 -07:00
Roger Wang
f17a1a8f96 [Misc] Make Serving Benchmark More User-friendly (#5044) 2024-05-25 17:28:16 +00:00
Lily Liu
d5a1697772 [Dynamic Spec Decoding] Minor fix for disabling speculative decoding (#5000) 2024-05-25 10:00:14 -07:00
youkaichao
325c119961 [Misc] add logging level env var (#5045) 2024-05-24 23:49:49 -07:00
Eric Xihui Lin
8e192ff967 [Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (#4799)
Co-authored-by: beagleski <yunanzhang@microsoft.com>
Co-authored-by: bapatra <bapatra@microsoft.com>
Co-authored-by: Barun Patra <codedecde@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-05-24 22:00:52 -07:00
leiwen83
e64fde4b01 [Core][Bugfix]: fix prefix caching for blockv2 (#4764)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
2024-05-24 10:07:09 -07:00
Robert Shaw
919770957f [Bugfix] Fix Mistral v0.3 Weight Loading (#5005)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-05-24 12:28:27 +00:00
youkaichao
6a50f4cafa [Doc] add ccache guide in doc (#5012)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-05-23 23:21:54 +00:00
Elisei Smirnov
e3470f8753 [Core]: Option To Use Prompt Token Ids Inside Logits Processor (#4985)
Co-authored-by: Elisei Smirnov <el.smirnov@innopolis.university>
2024-05-23 22:04:24 +00:00
Dipika Sikka
a1242324c9 [Kernel] Initial Activation Quantization Support (#4525)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2024-05-23 21:29:18 +00:00
Murali Andoorveedu
5eda2ea02a [Core][1/N] Support send/recv in PyNCCL Groups (#4988)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
2024-05-23 09:54:48 -07:00
Letian Li
2ba80bed27 [Bugfix] Update Dockerfile.cpu to fix NameError: name 'vllm_ops' is not defined (#5009) 2024-05-23 09:08:58 -07:00
Alexander Matveev
6066253296 Marlin 24 prefill performance improvement (about 25% better on average) (#4983) 2024-05-23 02:39:27 -04:00
Cody Yu
ee3eea0a1b [Misc] Take user preference in attention selector (#4960) 2024-05-23 07:55:56 +09:00
Philipp Moritz
a36de682d4 [Minor] Fix small typo in llama.py: QKVParallelLinear -> QuantizationConfig (#4991) 2024-05-22 22:26:56 +00:00
Nick Hill
eb6d3c264d [Core] Eliminate parallel worker per-step task scheduling overhead (#4894) 2024-05-23 06:17:27 +09:00
raywanb
97b030005c [Model] LoRA gptbigcode implementation (#3949) 2024-05-22 13:58:59 -07:00
Cody Yu
a3a73ab069 [Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)
The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
2024-05-22 13:28:20 -07:00
Tyler Michael Smith
8674f9880e [Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954)
Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs
2024-05-22 14:10:43 +00:00
SangBin Cho
c74c913bfb [misc] remove comments that were supposed to be removed (#4977) 2024-05-22 09:02:58 -04:00
Michael Goin
5f6d10c14c [CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722) 2024-05-22 07:18:41 +00:00
sasha0552
9b9a10d6cb [Frontend] Dynamic RoPE scaling (#4638) 2024-05-22 01:32:35 -04:00
Isotr0py
99eff67ba9 [Bugfix][Kernel] Add head size check for attention backend selection (#4944) 2024-05-21 15:33:25 -04:00
Kante Yin
14772eeb8e [Bugfix] Fix flag name for max_seq_len_to_capture (#4935)
Signed-off-by: kerthcet <kerthcet@gmail.com>
2024-05-21 09:30:52 -07:00
Michael Goin
757b62c495 [CI/Build] Codespell ignore build/ directory (#4945) 2024-05-21 09:06:10 -07:00
Simon Mo
e941f88584 [Docs] Add acknowledgment for sponsors (#4925) 2024-05-21 00:17:25 -07:00
Isotr0py
f12c3b5b3d [Model] Add Phi-2 LoRA support (#4886) 2024-05-21 14:24:17 +09:00
HUANG Fei
d130b573a0 [Model] add rope_scaling support for qwen2 (#4930) 2024-05-21 05:22:22 +00:00
Antoni Baum
65ae8c2c8f [Core] Fix scheduler considering "no LoRA" as "LoRA" (#4897) 2024-05-20 17:48:32 -07:00
Kuntai Du
c3af44722c [Doc]Add documentation to benchmarking script when running TGI (#4920) 2024-05-20 20:16:57 +00:00
Aurick Qiao
1937e29848 [Core] Sharded State Loader download from HF (#4889) 2024-05-20 11:46:12 -07:00
Mor Zusman
f0eecee610 [Bugfix] Fix dummy weight for fp8 (#4916)
Allow dummy load format for fp8,
torch.uniform_ doesn't support FP8 at the moment

Co-authored-by: Mor Zusman <morz@ai21.com>
2024-05-20 18:44:25 +00:00
Alexei-V-Ivanov-AMD
943e72ca56 [Build/CI] Enabling AMD Entrypoints Test (#4834)
Co-authored-by: Alexey Kondratiev <alexey.kondratiev@amd.com>
2024-05-20 11:29:28 -07:00
Wenwei Zhang
546a97ef69 [Misc]: allow user to specify port in distributed setting (#4914) 2024-05-20 17:45:06 +00:00
Alexander Matveev
da5a0b539d Remove marlin warning (#4918) 2024-05-20 14:55:34 +00:00
Cyrus Leung
6287537a0c [Model] LLaVA model refactor (#4910) 2024-05-20 08:11:25 +00:00
Woosuk Kwon
b57e6c5949 [Kernel] Add flash-attn back (#4907) 2024-05-19 18:11:30 -07:00
Alexander Matveev
27ce85476e [Kernel] Add marlin_24 unit tests (#4901) 2024-05-19 11:37:34 -04:00
Cyrus Leung
f68470e803 [Bugfix][Model] Add base class for vision-language models (#4809) 2024-05-19 00:13:33 -07:00
SangBin Cho
2e9a2227ec [Lora] Support long context lora (#4787)
Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
2024-05-18 16:05:23 +09:00
alexeykondrat
c0724fc915 [ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if Triton is not used (#4658) 2024-05-18 05:09:11 +00:00
Michael Goin
86b45ae065 [Bugfix] Relax tiktoken to >= 0.6.0 (#4890) 2024-05-17 12:58:52 -06:00
Antoni Baum
c5711ef985 [Doc] Update Ray Data distributed offline inference example (#4871) 2024-05-17 10:52:11 -07:00
eigenLiu
48d5985a08 Sync huggingface modifications of qwen Moe model (#4774) 2024-05-17 09:43:19 -07:00
Jinzhen Lin
33e0823de5 [Bugfix] fix rope error when load models with different dtypes (#4835) 2024-05-17 18:43:34 +09:00
Alexei-V-Ivanov-AMD
26148120b3 [Build/CI] Extending the set of AMD tests with Regression, Basic Correctness, Distributed, Engine, Llava Tests (#4797) 2024-05-16 20:58:25 -07:00
bofeng huang
0150a10630 [Frontend] OpenAI API server: Do not add bos token by default when encoding (#4688) 2024-05-16 18:47:22 -07:00
Kante Yin
8e7fb5d43a Support to serve vLLM on Kubernetes with LWS (#4829)
Signed-off-by: kerthcet <kerthcet@gmail.com>
2024-05-16 16:37:29 -07:00
Woosuk Kwon
9a31a817a8 [Bugfix] Fix FP8 KV cache support (#4869) 2024-05-16 22:42:29 +00:00
Tyler Michael Smith
2060e93659 [Kernel] Add w8a8 CUTLASS kernels (#4749) 2024-05-16 18:32:50 -04:00
Silencio
8435b207af [Kernel] Add punica dimension for Qwen1.5-32B LoRA (#4850)
Co-authored-by: Silencio <silencio@adsl-99-6-187-6.dsl.irvnca.sbcglobal.net>
2024-05-16 11:16:09 -07:00
youkaichao
10fa9eea21 [Misc] remove old comments (#4866) 2024-05-16 11:07:41 -07:00
youkaichao
e08188081b [Core][Distributed] remove graph mode function (#4818) 2024-05-16 10:59:52 -07:00
Hongxia Yang
b5853f9963 [ROCm][AMD][Bugfix] adding a missing triton autotune config (#4845) 2024-05-16 10:46:52 -07:00
Simon Mo
f09edd8a25 Add JSON output support for benchmark_latency and benchmark_throughput (#4848) 2024-05-16 10:02:56 -07:00
Alexander Matveev
6979ade384 Add GPTQ Marlin 2:4 sparse structured support (#4790)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
2024-05-16 12:56:15 -04:00
Pierre Dulac
9216b9cc38 [Bugfix] Bypass authorization API token for preflight requests (#4862) 2024-05-16 09:42:21 -07:00
Alex Wu
5e0391c040 [Frontend] Separate OpenAI Batch Runner usage from API Server (#4851) 2024-05-17 00:42:41 +09:00
Alex Wu
dbc0754ddf [docs] Fix typo in examples filename openi -> openai (#4864) 2024-05-17 00:42:17 +09:00
Jinzhen Lin
99caa49106 [Kernel] add bfloat16 support for gptq marlin kernel (#4788) 2024-05-16 09:55:29 -04:00
alexm-nm
5c342570d7 Add marlin unit tests and marlin benchmark script (#4815) 2024-05-16 09:36:49 -04:00
Cody Yu
973617ae02 [Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
2024-05-16 00:53:51 -07:00
Aurick Qiao
30e754390c [Core] Implement sharded state loader (#4690)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-05-15 22:11:54 -07:00
Alex Wu
52f8107cf2 [Frontend] Support OpenAI batch file format (#4794)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-05-15 19:13:36 -04:00
Cyrus Leung
fc0d9dfc3a [Frontend] Re-enable custom roles in Chat Completions API (#4758) 2024-05-15 14:58:46 -07:00
Zhuohan Li
361c461a12 [Doc] Highlight the fourth meetup in the README (#4842) 2024-05-15 11:38:49 -07:00
zifeitong
a5675d348b [Bugfix] Properly set distributed_executor_backend in ParallelConfig (#4816) 2024-05-15 07:22:09 -07:00
Cyrus Leung
e9cdd2b1e2 [CI/Build] Further decouple HuggingFace implementation from ours during tests (#4166) 2024-05-14 23:38:40 -07:00
SangBin Cho
65bf2ac165 [Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)
This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend.

It also refactors subquery_start_loc which was not refactored in the previous PR
2024-05-15 14:00:10 +09:00
SangBin Cho
8a7cc254a0 Revert "[Kernel] Use flash-attn for decoding (#3648)" (#4820)
Lora 3 & 4 test seems to have illegal memory access failure after this commit;

[2024-05-14 23:51:18,182 E 22 22] logging.cc:101: Unhandled exception: N3c105ErrorE. what(): CUDA error: an illegal memory access was encountered
<br class="Apple-interchange-newline">
Exmaple: https://buildkite.com/vllm/ci/builds/7382#018f793d-1527-4e1c-ab59-c3a34ec55241

This reverts commit 1356df5.

FILL IN THE PR DESCRIPTION HERE

FIX #xxxx (link existing issues this PR will resolve)
2024-05-15 11:52:45 +09:00
Simon Mo
29bc01bf3b Add 4th meetup announcement to readme (#4817) 2024-05-14 18:33:06 -04:00
Nick Hill
676a99982f [Core] Add MultiprocessingGPUExecutor (#4539)
Co-authored-by: SAHIL SUNEJA <suneja@us.ibm.com>
2024-05-14 10:38:59 -07:00
Cyrus Leung
dc72402b57 [Bugfix][Doc] Fix CI failure in docs (#4804)
This PR fixes the CI failure introduced by #4798.

The failure originates from having duplicate target names in reST, and is fixed by changing the ref targets to anonymous ones. For more information, see this discussion.

I have also changed the format of the links to be more distinct from each other.
2024-05-15 01:57:08 +09:00
Kuntai Du
ccb63a8245 [Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies (#4696) 2024-05-14 21:34:33 +09:00
Zhuohan Li
c579b750a0 [Doc] Add meetups to the doc (#4798) 2024-05-13 18:48:00 -07:00
Cyrus Leung
4bfa7e7f75 [Doc] Add API reference for offline inference (#4710) 2024-05-13 17:47:42 -07:00
Zhuohan Li
ac1fbf7fd2 [Doc] Shorten README by removing supported model list (#4796) 2024-05-13 16:23:54 -07:00
Philipp Moritz
33d3914b1e [Bugfix] Fix dynamic FP8 quantization for Mixtral (#4793) 2024-05-13 19:00:27 -04:00
Stephen Krider
1356df53bd [Kernel] Use flash-attn for decoding (#3648)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
2024-05-13 15:50:33 -07:00
Cody Yu
ce532ff45c [Speculative decoding] Improve n-gram efficiency (#4724) 2024-05-13 15:00:13 -07:00
Sanger Steel
8bc68e198c [Frontend] [Core] perf: Automatically detect vLLM-tensorized model, update tensorizer to version 2.9.0 (#4208) 2024-05-13 14:57:07 -07:00
Woosuk Kwon
0fca3cdcf2 [Misc] Enhance attention selector (#4751) 2024-05-13 10:47:25 -07:00
SangBin Cho
e7c46b9527 [Scheduler] Warning upon preemption and Swapping (#4647)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-05-13 23:50:44 +09:00
Cyrus Leung
350f9e107f [CI/Build] Move test_utils.py to tests/utils.py (#4425)
Since #4335 was merged, I've noticed that the definition of ServerRunner in the tests is the same as in the test for OpenAI API. I have moved the class to the test utilities to avoid code duplication. (Although it only has been repeated twice so far, I will add another similar test suite in #4200 which would duplicate the code a third time)

Also, I have moved the test utilities file (test_utils.py) to under the test directory (tests/utils.py), since none of its code is actually used in the main package. Note that I have added __init__.py to each test subpackage and updated the ray.init() call in the test utilities file in order to relative import tests/utils.py.
2024-05-13 23:50:09 +09:00
youkaichao
702bee461f [Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754) 2024-05-12 17:47:59 -07:00
Swapnil Parekh
a7be4d0072 [CORE] Improvement in ranks code (#4718) 2024-05-12 17:47:47 -07:00
Robert Shaw
a709e87a4f [CI/Build] Tweak Marlin Nondeterminism Issues (#4713) 2024-05-12 17:46:31 -07:00
Yikang Shen
6eaccb7353 [Model] Add support for IBM Granite Code models (#4636) 2024-05-11 21:27:24 -07:00
Chang Su
e254497b66 [Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734) 2024-05-11 11:30:37 -07:00
youkaichao
4e12131089 [Core][Test] fix function name typo in custom allreduce (#4750) 2024-05-10 15:14:40 -07:00
Robert Shaw
fcc2994be6 [CI] Nits for bad initialization of SeqGroup in testing (#4748) 2024-05-10 18:01:01 -04:00
heeju-kim2
2e7796f2cf [Speculative decoding] CUDA graph support (#4295)
Co-authored-by: Cade Daniel <edacih@gmail.com>
2024-05-10 17:36:25 +00:00
Allen.Dou
706588a77d [Bugfix] Fix CLI arguments in OpenAI server docs (#4729) 2024-05-11 00:00:56 +09:00
SangBin Cho
6a0f617210 [Core] Fix circular reference which leaked llm instance in local dev env (#4737)
Storing exception frame is extremely prone to circular refernece because it contains the reference to objects.

When tensorizer is not installed, it leaks llm instance because error frame has references to various modules which cause circular reference problem.

I also found spec decoding has a circular reference issue, and I solved it using weakref.proxy.
2024-05-10 23:54:32 +09:00
Steve Grubb
dac6a3f6ed [Misc] Apply a couple g++ cleanups (#4719) 2024-05-10 13:37:05 +00:00
Kunshang Ji
64b77dfd7e [Core]fix type annotation for swap_blocks (#4726) 2024-05-10 21:52:48 +09:00
Simon Mo
51d4094fda chunked-prefill-doc-syntax (#4603)
Fix the docs: https://docs.vllm.ai/en/latest/models/performance.html

Co-authored-by: sang <rkooo567@gmail.com>
2024-05-10 14:13:23 +09:00
Allen.Dou
e965d46184 [Misc] Keep only one implementation of the create_dummy_prompt function. (#4716) 2024-05-09 21:42:38 -07:00
youkaichao
208b71bcc1 [Core][Distributed] refactor pynccl (#4591)
[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)
2024-05-09 19:48:43 -07:00
Cody Yu
c833101740 [Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535) 2024-05-09 18:04:17 -06:00
Philipp Moritz
379da6dcb5 [Kernel] [FP8] Improve FP8 linear layer performance (#4691)
This PR improves the FP8 performance of linear layers, which had been lacking before (#4118 (comment) and #4118 (comment)).

We noticed that CUBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16. So this PR enlarges matrices appropriately during quantization. This improves FP8 performance and removes the performance regression vs. FP16, in many cases exceeding FP16 performance.

Here are benchmarks on llama3 70b (ITL numbers for 1000 input and 50 output tokens at fixed qps and at TP 4), all FP8 measurements are for dynamic quantization:

qps = 1: 24 ms (FP8, this PR), 32 ms (FP8, previous main), 26 ms (FP16)
qps = 2: 26 ms (FP8, this PR), 34ms (FP8, previous main), 28 ms (FP16) 
qps = 4: 33 ms (FP8, this PR), 44 ms (FP8, previous main), 36 ms (FP16)
qps = 6: 46 ms (FP8, this PR), 56 ms (FP8, previous main), 54 ms (FP16)
qps = 8: 85 ms (FP8, this PR), 85 ms (FP8, previous main), 138 ms (FP16)
2024-05-09 16:38:07 -07:00
Hao Zhang
ebce310b74 [Model] Snowflake arctic model implementation (#4652)
Co-authored-by: Dash Desai <1723932+iamontheinet@users.noreply.github.com>
Co-authored-by: Aurick Qiao <qiao@aurick.net>
Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
Co-authored-by: Aurick Qiao <aurickq@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-05-09 22:37:14 +00:00
Michael Goin
be0c5180ac [Bugfix] Add logs for all model dtype casting (#4717) 2024-05-09 18:36:25 +00:00
Robert Shaw
cea64430f6 [Bugfix] Update grafana.json (#4711) 2024-05-09 10:10:13 -07:00
Cyrus Leung
a3c124570a [Bugfix] Fix CLI arguments in OpenAI server docs (#4709) 2024-05-09 09:53:14 -07:00
kliuae
ff5abcd746 [ROCm] Add support for Punica kernels on AMD GPUs (#3140)
Co-authored-by: miloice <jeffaw99@hotmail.com>
2024-05-09 09:19:50 -07:00
Woosuk Kwon
0ee535b294 [Misc] Set block size at initialization & Fix test_model_runner (#4705) 2024-05-09 09:04:59 -07:00
Woosuk Kwon
190bc838e1 [Misc] Remove unnecessary ModelRunner imports (#4703) 2024-05-09 00:17:17 -07:00
Cyrus Leung
f12b20decc [Frontend] Move async logic outside of constructor (#4674) 2024-05-08 22:48:33 -07:00
Mahmoud Ashraf
16bc0a098f [Frontend] add tok/s speed metric to llm class when using tqdm (#4400)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-05-08 22:02:31 -07:00
alexm-nm
e288df0632 [Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (#4626) 2024-05-08 17:14:31 -07:00
Cade Daniel
8b9241be3a [Speculative decoding] [Bugfix] Fix overallocation in ngram + spec logprobs (#4672) 2024-05-08 23:24:46 +00:00
Cody Yu
f942efb5a3 [Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)
Co-authored-by: Cade Daniel <edacih@gmail.com>
2024-05-08 21:44:00 +00:00
Woosuk Kwon
89579a201f [Misc] Use vllm-flash-attn instead of flash-attn (#4686) 2024-05-08 13:15:34 -07:00
youkaichao
230c4b38c1 [CI/Test] fix swap test for multi gpu (#4689) 2024-05-08 13:14:02 -07:00
youkaichao
20cfcdec99 [Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659) 2024-05-08 12:07:05 -07:00
Antoni Baum
ad932a221d [Core] Faster startup for LoRA enabled models (#4634) 2024-05-08 10:33:18 -07:00
Woosuk Kwon
5510cf0e8a [Misc] Add get_name method to attention backends (#4685) 2024-05-08 09:59:31 -07:00
DefTruth
0f9a6e3d22 [Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi (#4573) 2024-05-08 09:19:58 -07:00
SangBin Cho
f6a593093a [CI] Make mistral tests pass (#4596) 2024-05-08 08:44:35 -07:00
SangBin Cho
d7740ea4dc [Core] Optimize sampler get_logprobs (#4594) 2024-05-08 08:42:28 -07:00
youkaichao
cc466a3290 [Core][Distributed] support cpu&device in broadcast tensor dict (#4660)
[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660)
2024-05-07 19:34:47 -07:00
leiwen83
8344f7742b [Bug fix][Core] fixup ngram not setup correctly (#4551)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-05-07 11:40:18 -07:00
youkaichao
469f85c782 [Core][Optimization] change copy-on-write from dict[int, list] to list (#4648) 2024-05-07 11:06:32 -07:00
Austin Veselka
10760da800 [Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithLora (#4609) 2024-05-07 10:59:07 -07:00
Alexei-V-Ivanov-AMD
478aed5827 [Build/CI] Fixing 'docker run' to re-enable AMD CI tests. (#4642) 2024-05-07 09:23:17 -07:00
youkaichao
63575bc2e1 [Core][Optimization] change python dict to pytorch tensor (#4607) 2024-05-06 21:30:27 -07:00
Philipp Moritz
a98187cf72 [Kernel] Make static FP8 scaling more robust (#4570)
Previously FP8 static scaling works if the scales are overestimating the maxima of all activation tensors during computation. However this will not always be the case even if the scales were calibrated very carefully. For example, with the activations in my checkpoint

https://huggingface.co/pcmoritz/Mixtral-8x7B-v0.1-fp8-act-scale

(which was calibrated on https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), I'm getting the following mostly random performance on MMLU:

|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.2295|±  |0.0035|
| - humanities     |N/A    |none  |     5|acc   |0.2421|±  |0.0062|
| - other          |N/A    |none  |     5|acc   |0.2398|±  |0.0076|
| - social_sciences|N/A    |none  |     5|acc   |0.2171|±  |0.0074|
| - stem           |N/A    |none  |     5|acc   |0.2125|±  |0.0073|
With the fix in this PR where the scaled activations are clamped between [-std::numeric_limits<c10::Float8_e4m3fn>::max(), std::numeric_limits<c10::Float8_e4m3fn>::max()] to make sure there are no NaNs, the performance is

|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7008|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6453|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7692|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8083|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6115|±  |0.0083|
This is not perfect yet but is getting very close to the FP16 / dynamic activation scale performance.
2024-05-06 17:39:28 -07:00
Noam Gat
bd99d22629 Update lm-format-enforcer to 0.10.1 (#4631) 2024-05-06 23:51:59 +00:00
Cade Daniel
19cb4716ee [CI] Add retry for agent lost (#4633) 2024-05-06 23:18:57 +00:00
Simon Mo
e186d37cb1 [CI] use ccache actions properly in release workflow (#4629) 2024-05-06 22:23:36 +00:00
Cyrus Leung
323f27b904 [Bugfix] Fix asyncio.Task not being subscriptable (#4623) 2024-05-06 09:31:05 -07:00
zhaoyang-star
0650e5935b Disable cuda version check in vllm-openai image (#4530) 2024-05-05 16:58:55 -07:00
Simon Mo
c7f2cf2b7f [CI] Reduce wheel size by not shipping debug symbols (#4602)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.10, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.11, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.8, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.9, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.10, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.11, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.8, 2.3.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.9, 2.3.0) (push) Has been cancelled
2024-05-04 21:28:58 -07:00
Simon Mo
8d8357c8ed bump version to v0.4.2 (#4600) 2024-05-04 17:09:49 -07:00
DearPlanet
4302987069 [Bugfix] Fix inappropriate content of model_name tag in Prometheus metrics (#3937) 2024-05-04 15:39:34 -07:00
Simon Mo
021b1a2ab7 [CI] check size of the wheels (#4319) 2024-05-04 20:44:36 +00:00
Michael Goin
2a052011ca [Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527)
Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
2024-05-04 11:45:16 -07:00
SangBin Cho
36fb68f947 [Doc] Chunked Prefill Documentation (#4580) 2024-05-04 00:18:00 -07:00
Cody Yu
bc8ad68455 [Misc][Refactor] Introduce ExecuteModelData (#4540) 2024-05-03 17:47:07 -07:00
youkaichao
344bf7cd2d [Misc] add installation time env vars (#4574) 2024-05-03 15:55:56 -07:00
Cade Daniel
ab50275111 [Speculative decoding] Support target-model logprobs (#4378) 2024-05-03 15:52:01 -07:00
Lily Liu
43c413ec57 [Kernel] Use flashinfer for decoding (#4353)
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
2024-05-03 15:51:27 -07:00
Sebastian Schoennenbeck
f8e7adda21 Fix/async chat serving (#2727) 2024-05-03 11:04:14 -07:00
Michael Goin
7e65477e5e [Bugfix] Allow "None" or "" to be passed to CLI for string args that default to None (#4586) 2024-05-03 10:32:21 -07:00
SangBin Cho
3521ba4f25 [Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518) 2024-05-03 10:20:12 -07:00
youkaichao
2d7bce9cd5 [Doc] add env vars to the doc (#4572) 2024-05-03 05:13:49 +00:00
DefTruth
ce3f1eedf8 [Misc] remove chunk detected debug logs (#4571) 2024-05-03 04:48:08 +00:00
Yang, Bo
808632d3b4 [BugFix] Prevent the task of _force_log from being garbage collected (#4567) 2024-05-03 01:35:18 +00:00
youkaichao
344a5d0c33 [Core][Distributed] enable allreduce for multiple tp groups (#4566) 2024-05-02 17:32:33 -07:00
SangBin Cho
0f8a91401c [Core] Ignore infeasible swap requests. (#4557) 2024-05-02 14:31:20 -07:00
Alexei-V-Ivanov-AMD
9b5c9f9484 [CI/Build] AMD CI pipeline with extended set of tests. (#4267)
Co-authored-by: simon-mo <simon.mo@hey.com>
2024-05-02 12:29:07 -07:00
Michał Moskal
32881f3f31 [kernel] fix sliding window in prefix prefill Triton kernel (#4405)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
2024-05-02 11:23:37 -07:00
youkaichao
5b8a7c1cb0 [Misc] centralize all usage of environment variables (#4548) 2024-05-02 11:13:25 -07:00
Mark McLoughlin
1ff0c73a79 [BugFix] Include target-device specific requirements.txt in sdist (#4559) 2024-05-02 10:52:51 -07:00
Hu Dong
5ad60b0cbd [Misc] Exclude the tests directory from being packaged (#4552) 2024-05-02 10:50:25 -07:00
SangBin Cho
fb087af52e [mypy][7/N] Cover all directories (#4555) 2024-05-02 10:47:41 -07:00
alexm-nm
7038e8b803 [Kernel] Support running GPTQ 8-bit models in Marlin (#4533) 2024-05-02 12:56:22 -04:00
youkaichao
2a85f93007 [Core][Distributed] enable multiple tp group (#4512)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2024-05-02 04:28:21 +00:00
SangBin Cho
cf8cac8c70 [mypy][6/N] Fix all the core subdirectory typing (#4450)
Co-authored-by: Cade Daniel <edacih@gmail.com>
2024-05-02 03:01:00 +00:00
Ronen Schaffer
5e401bce17 [CI]Add regression tests to ensure the async engine generates metrics (#4524) 2024-05-01 19:57:12 -07:00
SangBin Cho
0d62fe58db [Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (#4451) 2024-05-01 19:24:13 -07:00
Danny Guinther
b8afa8b95a [MISC] Rework logger to enable pythonic custom logging configuration to be provided (#4273) 2024-05-01 17:34:40 -07:00
Woosuk Kwon
826b82a260 [Misc] Fix expert_ids shape in MoE (#4517) 2024-05-01 23:47:59 +00:00
Philipp Moritz
c9d852d601 [Misc] Remove Mixtral device="cuda" declarations (#4543)
Remove the device="cuda" declarations in mixtral as promised in #4343
2024-05-01 16:30:52 -07:00
youkaichao
6ef09b08f8 [Core][Distributed] fix pynccl del error (#4508) 2024-05-01 15:23:06 -07:00
Roy
3a922c1e7e [Bugfix][Core] Fix and refactor logging stats (#4336) 2024-05-01 20:08:14 +00:00
sasha0552
c47ba4aaa9 [Bugfix] Add validation for seed (#4529) 2024-05-01 19:31:22 +00:00
Philipp Moritz
24bb4fe432 [Kernel] Update fused_moe tuning script for FP8 (#4457)
This PR updates the tuning script for the fused_moe kernel to support FP8 and also adds configurations for TP4. Note that for the configuration I removed num_warps and num_stages for small batch sizes since that improved performance and brought the benchmarks on par with the numbers before in that regime to make sure this is a strict improvement over the status quo.

All the numbers below are for mistralai/Mixtral-8x7B-Instruct-v0.1, 1000 input and 50 output tokens.

Before this PR (with static activation scaling):

qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency 
qps = 4: 10.1 ms ITL, 0.52s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 14.0 ms ITL, 0.70s e2e latency
qps = 10: 15.7 ms ITL, 0.79s e2e latency

After this PR (with static activation scaling):

qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency
qps = 4: 10.2 ms ITL, 0.53s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 11.9 ms ITL, 0.59s e2e latency
qps = 10: 12.1 ms ITL, 0.61s e2e latency
2024-05-01 11:47:38 -07:00
Nick Hill
a657bfc48a [Core] Add multiproc_worker_utils for multiprocessing-based workers (#4357) 2024-05-01 18:41:59 +00:00
leiwen83
24750f4cad [Core] Enable prefix caching with block manager v2 enabled (#4142)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Sage Moore <sagemoore@utexas.edu>
2024-05-01 11:20:32 -07:00
leiwen83
b38e42fbca [Speculative decoding] Add ngram prompt lookup decoding (#4237)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
2024-05-01 11:13:03 -07:00
Travis Johnson
8b798eec75 [CI/Build][Bugfix] VLLM_USE_PRECOMPILED should skip compilation (#4534)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
2024-05-01 18:01:50 +00:00
sasha0552
69909126a7 [Bugfix] Use random seed if seed is -1 (#4531) 2024-05-01 10:41:17 -07:00
Frαnçois
e491c7e053 [Doc] update(example model): for OpenAI compatible serving (#4503) 2024-05-01 10:14:16 -07:00
Robert Shaw
4dc8026d86 [Bugfix] Fix 307 Redirect for /metrics (#4523) 2024-05-01 09:14:13 -07:00
AnyISalIn
a88bb9b032 [Bugfix] Fix the fp8 kv_cache check error that occurs when failing to obtain the CUDA version. (#4173)
Signed-off-by: AnyISalIn <anyisalin@gmail.com>
2024-05-01 09:11:03 -07:00
SangBin Cho
6f1df80436 [Test] Add ignore_eos test (#4519) 2024-05-01 08:45:42 -04:00
Jee Li
d6f4bd7cdd [Misc]Add customized information for models (#4132) 2024-04-30 21:18:14 -07:00
Robert Caulk
c3845d82dc Allow user to define whitespace pattern for outlines (#4305) 2024-04-30 20:48:39 -07:00
Pastel!
a822eb3413 [Misc] fix typo in block manager (#4453) 2024-04-30 20:41:32 -07:00
harrywu
f458112e8a [Misc][Typo] type annotation fix (#4495) 2024-04-30 20:21:39 -07:00
Nick Hill
2e240c69a9 [Core] Centralize GPU Worker construction (#4419) 2024-05-01 01:06:34 +00:00
fuchen.ljl
ee37328da0 Unable to find Punica extension issue during source code installation (#4494)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-05-01 00:42:09 +00:00
fuchen.ljl
6ad58f42c5 fix_tokenizer_snapshot_download_bug (#4493) 2024-04-30 16:38:50 -07:00
Li, Jiang
dd1a50a8bc [Bugfix][Minor] Make ignore_eos effective (#4468) 2024-04-30 16:33:33 -07:00
Alpay Ariyak
715c2d854d [Frontend] [Core] Tensorizer: support dynamic num_readers, update version (#4467) 2024-04-30 16:32:13 -07:00
Florian Greinacher
a494140433 [Frontend] Support complex message content for chat completions endpoint (#3467)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2024-04-30 16:28:46 -07:00
Robert Shaw
111815d482 [Kernel] Support Fp8 Checkpoints (Dynamic + Static) (#4332)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-04-30 21:46:12 +00:00
Prashant Gupta
b31a1fb63c [Doc] add visualization for multi-stage dockerfile (#4456)
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-04-30 17:41:59 +00:00
leiwen83
4bb53e2dde [BugFix] fix num_lookahead_slots missing in async executor (#4165)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
2024-04-30 10:12:59 -07:00
Kunshang Ji
26f2fb5113 [Core]Refactor gptq_marlin ops (#4466) 2024-04-30 08:14:47 -04:00
Woosuk Kwon
fa32207842 [Bugfix][Kernel] Fix compute_type for MoE kernel (#4463) 2024-04-29 22:05:40 -07:00
Michael Goin
d627a3d837 [Misc] Upgrade to torch==2.3.0 (#4454) 2024-04-29 20:05:47 -04:00
youkaichao
f4f921b7f1 [Core][Distributed] use cpu group to broadcast metadata in cpu (#4444) 2024-04-29 13:52:22 -07:00
Simon Mo
ac5ccf0156 [CI] hotfix: soft fail neuron test (#4458) 2024-04-29 19:50:01 +00:00
Robert Shaw
73c8d677e5 [Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922)
Co-authored-by: alexm <alexm@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
2024-04-29 09:35:34 -07:00
SangBin Cho
df29793dc7 [mypy][5/N] Support all typing on model executor (#4427) 2024-04-28 19:01:26 -07:00
Simon Mo
03dd7d52bf [CI] clean docker cache for neuron (#4441) 2024-04-28 23:32:07 +00:00
Ronen Schaffer
bf480c5302 Add more Prometheus metrics (#2764)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
2024-04-28 15:59:33 -07:00
DefTruth
9c7306ac11 [Misc] fix typo in llm_engine init logging (#4428) 2024-04-28 18:58:30 +08:00
Robert Shaw
4ea1f9678d [BugFix] Resolved Issues For LinearMethod --> QuantConfig (#4418) 2024-04-27 18:35:33 +00:00
Nick Hill
ba4be44c32 [BugFix] Fix return type of executor execute_model methods (#4402) 2024-04-27 11:17:45 -07:00
Prashant Gupta
d6e520e170 [Core] Support offline use of local cache for models (#4374)
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Travis Johnson <tjohnson31415@gmail.com>
2024-04-27 09:59:55 -07:00
Nick Hill
81661da7b2 [BugFix] Fix min_tokens when eos_token_id is None (#4389)
Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
2024-04-27 09:52:46 -07:00
Ruoyu Qin
dfea173148 [Bugfix] Abort requests when the connection to /v1/completions is interrupted (#4363) 2024-04-27 09:48:37 -07:00
Roy
7134303cbb [Bugfix][Core] Fix get decoding config from ray (#4335) 2024-04-27 11:30:08 +00:00
Caio Mendes
3da24c2df7 [Model] Phi-3 4k sliding window temp. fix (#4380) 2024-04-27 18:08:15 +08:00
Austin Veselka
eefeb16464 [Kernel] Full Tensor Parallelism for LoRA Layers (#3524)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
2024-04-27 00:03:48 -07:00
Hongxia Yang
18d23f642a [ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406) 2024-04-26 23:37:40 -07:00
Roy
87f545ba6f [Misc] Fix logger format typo (#4396) 2024-04-27 13:45:02 +08:00
Cyrus Leung
8947bc3c15 [Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355) 2024-04-27 05:08:24 +00:00
Philipp Moritz
12628d3c78 [Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales (#4343)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-04-27 04:49:59 +00:00
Nick Hill
258a2c58d0 [Core] Introduce DistributedGPUExecutor abstract class (#4348) 2024-04-27 04:14:26 +00:00
youkaichao
aba47be3fe [Misc] add RFC issue template (#4401)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-26 15:47:45 -07:00
Cody Yu
a62aaf1df5 [Misc][Refactor] Generalize linear_method to be quant_method (#4373) 2024-04-26 16:41:14 -04:00
SangBin Cho
603ad84815 [Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309) 2024-04-26 13:02:02 +00:00
SangBin Cho
a88081bf76 [CI] Disable non-lazy string operation on logging (#4326)
Co-authored-by: Danny Guinther <dguinther@neuralmagic.com>
2024-04-26 00:16:58 -07:00
Norman Mu
2f30e7c72f [Frontend] Add --log-level option to api server (#4377) 2024-04-26 05:36:01 +00:00
Cyrus Leung
a74dee9b62 [Bugfix] Fix parameter name in get_tokenizer (#4107) 2024-04-25 19:10:48 -07:00
Hongxia Yang
cf29b7eda4 [ROCm][Hardware][AMD][Doc] Documentation update for ROCm (#4376)
Co-authored-by: WoosukKwon <woosuk.kwon@berkeley.edu>
2024-04-25 18:12:25 -07:00
Nick Hill
efffb63f58 [Core] Move function tracing setup to util function (#4352) 2024-04-25 16:45:12 -07:00
Nick Hill
15e7c675b0 [Core] Add shutdown() method to ExecutorBase (#4349) 2024-04-25 16:32:48 -07:00
Roy
b6dcb4d442 [Misc] Fix flash attention backend log (#4368) 2024-04-25 12:43:32 -07:00
SangBin Cho
b5b4a398a7 [Mypy] Typing lora folder (#4337) 2024-04-25 19:13:50 +00:00
Kunshang Ji
f4bc4de1b1 [Core]refactor aqlm quant ops (#4351) 2024-04-25 15:03:56 -04:00
Caio Mendes
bd7a8eef25 [Doc] README Phi-3 name fix. (#4372)
Co-authored-by: Caio Mendes <caiocesart@microsoft.com>
2024-04-25 10:32:00 -07:00
Alexei-V-Ivanov-AMD
7ee82bef1e [CI/Build] Adding functionality to reset the node's GPUs before processing. (#4213) 2024-04-25 09:37:20 -07:00
Isotr0py
fbf152d976 [Bugfix][Model] Refactor OLMo model to support new HF format in transformers 4.40.0 (#4324)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-04-25 09:35:56 -07:00
Nick Hill
479d69fad0 [Core] Move ray_utils.py from engine to executor package (#4347) 2024-04-25 06:52:22 +00:00
Caio Mendes
96e90fdeb3 [Model] Adds Phi-3 support (#4298) 2024-04-25 03:06:57 +00:00
zifeitong
a395a638c2 [Misc] Use public API in benchmark_throughput (#4300) 2024-04-24 21:10:24 +00:00
youkaichao
2768884ac4 [Doc] Add note for docker user (#4340)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-24 21:09:44 +00:00
alexm-nm
aae08249ac [Bugfix] Fix marlin kernel crash on H100 (#4218)
This PR addresses the Marlin kernel H100 crash that was reported here: neuralmagic#187.
The reason for the crash was the inline PTX assembly that introduced the async_copy with streaming behavior. The solution is to use the more standard PTX for async_copy (without the fractional L2 policy for "evict_first"). There is no performance difference between standard async_copy PTX and the previous one.
2024-04-24 10:35:01 -07:00
Roger Wang
7923dcad12 [Misc] Update ShareGPT Dataset Sampling in Serving Benchmark (#4279) 2024-04-24 09:49:13 -07:00
youkaichao
3cd9b5bb2d [Core][Distributed] use existing torch.cuda.device (#4318)
[Core][Distributed] use existing torch.cuda.device context manager (#4318)
2024-04-24 09:00:20 -07:00
542 changed files with 53105 additions and 16495 deletions

View File

@@ -0,0 +1,36 @@
import os
import zipfile
MAX_SIZE_MB = 200
def print_top_10_largest_files(zip_file):
with zipfile.ZipFile(zip_file, 'r') as z:
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
file_sizes.sort(key=lambda x: x[1], reverse=True)
for f, size in file_sizes[:10]:
print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
def check_wheel_size(directory):
for root, _, files in os.walk(directory):
for f in files:
if f.endswith(".whl"):
wheel_path = os.path.join(root, f)
wheel_size = os.path.getsize(wheel_path)
wheel_size_mb = wheel_size / (1024 * 1024)
if wheel_size_mb > MAX_SIZE_MB:
print(
f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
f"compare to the allowed size ({MAX_SIZE_MB} MB).")
print_top_10_largest_files(wheel_path)
return 1
else:
print(f"Wheel {wheel_path} is within the allowed size "
f"({wheel_size_mb} MB).")
return 0
if __name__ == "__main__":
import sys
sys.exit(check_wheel_size(sys.argv[1]))

View File

@@ -0,0 +1,26 @@
#!/usr/bin/env bash
set -euo pipefail
# Install system packages
apt update
apt install -y curl jq
# Install minijinja for templating
curl -sSfL https://github.com/mitsuhiko/minijinja/releases/latest/download/minijinja-cli-installer.sh | sh
source $HOME/.cargo/env
# If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq
if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then
PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name')
if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then
echo "This PR has the 'perf-benchmarks' label. Proceeding with the nightly benchmarks."
else
echo "This PR does not have the 'perf-benchmarks' label. Skipping the nightly benchmarks."
exit 0
fi
fi
# Upload sample.yaml
buildkite-agent pipeline upload .buildkite/nightly-benchmarks/sample.yaml

View File

@@ -0,0 +1,39 @@
steps:
# NOTE(simon): You can create separate blocks for different jobs
- label: "A100: NVIDIA SMI"
agents:
queue: A100
plugins:
- kubernetes:
podSpec:
containers:
# - image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT
# TODO(simon): check latest main branch or use the PR image.
- image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
command:
- bash -c 'nvidia-smi && nvidia-smi topo -m && pwd && ls'
resources:
limits:
nvidia.com/gpu: 8
volumeMounts:
- name: devshm
mountPath: /dev/shm
nodeSelector:
nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB
volumes:
- name: devshm
emptyDir:
medium: Memory
# TODO(simon): bring H100 online
# - label: "H100: NVIDIA SMI"
# agents:
# queue: H100
# plugins:
# - docker#v5.11.0:
# image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
# command:
# - bash -c 'nvidia-smi && nvidia-smi topo -m'
# propagate-environment: true
# ipc: host
# gpus: all

View File

@@ -1,38 +1,73 @@
# This script build the ROCm docker image and run the API server inside the container. # This script runs test inside the corresponding ROCm docker container.
# It serves a sanity check for compilation and basic model usage.
set -ex set -ex
# Print ROCm version # Print ROCm version
echo "--- ROCm info"
rocminfo rocminfo
# Try building the docker image # cleanup older docker images
docker build -t rocm -f Dockerfile.rocm . cleanup_docker() {
# Get Docker's root directory
docker_root=$(docker info -f '{{.DockerRootDir}}')
if [ -z "$docker_root" ]; then
echo "Failed to determine Docker root directory."
exit 1
fi
echo "Docker root directory: $docker_root"
# Check disk usage of the filesystem where Docker's root directory is located
disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
# Define the threshold
threshold=70
if [ "$disk_usage" -gt "$threshold" ]; then
echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
# Remove dangling images (those that are not tagged and not used by any container)
docker image prune -f
# Remove unused volumes
docker volume prune -f
echo "Docker images and volumes cleanup completed."
else
echo "Disk usage is below $threshold%. No cleanup needed."
fi
}
# Setup cleanup # Call the cleanup docker function
remove_docker_container() { docker rm -f rocm || true; } cleanup_docker
trap remove_docker_container EXIT
remove_docker_container
# Run the image echo "--- Resetting GPUs"
docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server &
# Wait for the server to start echo "reset" > /opt/amdgpu/etc/gpu_state
wait_for_server_to_start() {
timeout=300
counter=0
while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do while true; do
sleep 1 sleep 3
counter=$((counter + 1)) if grep -q clean /opt/amdgpu/etc/gpu_state; then
if [ $counter -ge $timeout ]; then echo "GPUs state is \"clean\""
echo "Timeout after $timeout seconds"
break break
fi fi
done done
}
wait_for_server_to_start
# Test a simple prompt echo "--- Building container"
curl -X POST -H "Content-Type: application/json" \ sha=$(git rev-parse --short HEAD)
localhost:8000/generate \ image_name=rocm_${sha}
-d '{"prompt": "San Francisco is a"}' container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)
docker build \
-t ${image_name} \
-f Dockerfile.rocm \
--progress plain \
.
remove_docker_container() {
docker rm -f ${container_name} || docker image rm -f ${image_name} || true
}
trap remove_docker_container EXIT
echo "--- Running container"
docker run \
--device /dev/kfd --device /dev/dri \
--network host \
--rm \
-e HF_TOKEN \
--name ${container_name} \
${image_name} \
/bin/bash -c "${@}"

View File

@@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.."
(which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
# run python-based benchmarks and upload the result to buildkite # run python-based benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$? bench_latency_exit_code=$?
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$? bench_throughput_exit_code=$?
# run server-based benchmarks and upload the result to buildkite # run server-based benchmarks and upload the result to buildkite
@@ -50,11 +50,16 @@ echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md echo "" >> benchmark_results.md
echo '```' >> benchmark_results.md echo '```' >> benchmark_results.md
tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines
echo '```' >> benchmark_results.md echo '```' >> benchmark_results.md
# if the agent binary is not found, skip uploading the results, exit 0
if [ ! -f /usr/bin/buildkite-agent ]; then
exit 0
fi
# upload the results to buildkite # upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
# exit with the exit code of the benchmarks # exit with the exit code of the benchmarks
if [ $bench_latency_exit_code -ne 0 ]; then if [ $bench_latency_exit_code -ne 0 ]; then
@@ -69,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then
exit $bench_serving_exit_code exit $bench_serving_exit_code
fi fi
/workspace/buildkite-agent artifact upload openai-*.json rm ShareGPT_V3_unfiltered_cleaned_split.json
buildkite-agent artifact upload "*.json"

View File

@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT trap remove_docker_container EXIT
remove_docker_container remove_docker_container
# Run the image and launch offline inference # Run the image
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
# offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf
bash ../.buildkite/download-images.sh
cd ../
pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"

View File

@@ -4,6 +4,20 @@ set -e
# Try building the docker image # Try building the docker image
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
# prune old image and containers to save disk space, and only once a day
# by using a timestamp file in tmp.
if [ -f /tmp/neuron-docker-build-timestamp ]; then
last_build=$(cat /tmp/neuron-docker-build-timestamp)
current_time=$(date +%s)
if [ $((current_time - last_build)) -gt 86400 ]; then
docker system prune -f
echo $current_time > /tmp/neuron-docker-build-timestamp
fi
else
echo $(date +%s) > /tmp/neuron-docker-build-timestamp
fi
docker build -t neuron -f Dockerfile.neuron . docker build -t neuron -f Dockerfile.neuron .
# Setup cleanup # Setup cleanup

View File

@@ -5,102 +5,158 @@
steps: steps:
- label: Regression Test - label: Regression Test
mirror_hardwares: [amd]
command: pytest -v -s test_regression.py command: pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional working_dir: "/vllm-workspace/tests" # optional
- label: AsyncEngine Test - label: AsyncEngine Test
#mirror_hardwares: [amd]
command: pytest -v -s async_engine command: pytest -v -s async_engine
- label: Basic Correctness Test - label: Basic Correctness Test
mirror_hardwares: [amd]
commands: commands:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
- label: Core Test - label: Core Test
mirror_hardwares: [amd]
command: pytest -v -s core command: pytest -v -s core
- label: Distributed Comm Ops Test - label: Distributed Comm Ops Test
command: pytest -v -s test_comm_ops.py #mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests/distributed" command: pytest -v -s distributed/test_comm_ops.py
num_gpus: 2 # only support 1 or 2 for now. working_dir: "/vllm-workspace/tests"
num_gpus: 2
- label: Distributed Tests - label: Distributed Tests
working_dir: "/vllm-workspace/tests/distributed" mirror_hardwares: [amd]
num_gpus: 2 # only support 1 or 2 for now. working_dir: "/vllm-workspace/tests"
num_gpus: 2
commands: commands:
- pytest -v -s test_pynccl.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- pytest -v -s distributed/test_pynccl.py
- label: Engine Test - label: Engine Test
mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
- label: Entrypoints Test - label: Entrypoints Test
mirror_hardwares: [amd]
commands: commands:
# these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - pytest -v -s entrypoints -m openai
- pytest -v -s entrypoints/test_server_oot_registration.py
- label: Examples Test - label: Examples Test
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
mirror_hardwares: [amd]
commands: commands:
# install aws cli for llava_example.py # install aws cli for llava_example.py
- pip install awscli # install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer
- python3 offline_inference.py - python3 offline_inference.py
- python3 offline_inference_with_prefix.py - python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py - python3 llm_engine_example.py
- python3 llava_example.py - python3 llava_example.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- label: Inputs Test
#mirror_hardwares: [amd]
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s test_inputs.py
- pytest -v -s multimodal
- label: Kernels Test %N - label: Kernels Test %N
#mirror_hardwares: [amd]
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4 parallelism: 4
- label: Models Test - label: Models Test
#mirror_hardwares: [amd]
commands: commands:
- bash ../.buildkite/download-images.sh - pytest -v -s models -m \"not llava\"
- pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
- label: Llava Test - label: Llava Test
mirror_hardwares: [amd]
commands: commands:
- bash ../.buildkite/download-images.sh - bash ../.buildkite/download-images.sh
- pytest -v -s models/test_llava.py - pytest -v -s models -m llava
- label: Prefix Caching Test - label: Prefix Caching Test
mirror_hardwares: [amd]
commands: commands:
- pytest -v -s prefix_caching - pytest -v -s prefix_caching
- label: Samplers Test - label: Samplers Test
#mirror_hardwares: [amd]
command: pytest -v -s samplers command: pytest -v -s samplers
- label: LogitsProcessor Test - label: LogitsProcessor Test
mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py command: pytest -v -s test_logits_processor.py
- label: Utils Test
command: pytest -v -s test_utils.py
- label: Worker Test - label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker command: pytest -v -s worker
- label: Speculative decoding tests - label: Speculative decoding tests
command: pytest -v -s spec_decode #mirror_hardwares: [amd]
commands:
# See https://github.com/vllm-project/vllm/issues/5152
- export VLLM_ATTENTION_BACKEND=XFORMERS
- pytest -v -s spec_decode
- label: LoRA Test %N - label: LoRA Test %N
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT #mirror_hardwares: [amd]
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4 parallelism: 4
- label: LoRA Long Context (Distributed)
#mirror_hardwares: [amd]
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
- pytest -v -s -x lora/test_long_context.py
- label: Tensorizer Test - label: Tensorizer Test
#mirror_hardwares: [amd]
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
- label: Metrics Test - label: Metrics Test
mirror_hardwares: [amd]
command: pytest -v -s metrics command: pytest -v -s metrics
- label: Quantization Test - label: Quantization Test
#mirror_hardwares: [amd]
command: pytest -v -s quantization command: pytest -v -s quantization
- label: Benchmarks - label: Benchmarks
working_dir: "/vllm-workspace/.buildkite" working_dir: "/vllm-workspace/.buildkite"
mirror_hardwares: [amd]
commands: commands:
- pip install aiohttp - pip install aiohttp
- bash run-benchmarks.sh - bash run-benchmarks.sh

View File

@@ -0,0 +1,64 @@
{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %}
{% set default_working_dir = "/vllm-workspace/tests" %}
steps:
- label: ":docker: build image"
agents:
queue: cpu_queue
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
DOCKER_BUILDKIT: "1"
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
- exit_status: -10 # Agent was lost
limit: 5
- wait
{% for step in steps %}
- label: "{{ step.label }}"
agents:
{% if step.label == "Documentation Build" %}
queue: small_cpu_queue
{% elif step.no_gpu %}
queue: cpu_queue
{% elif step.num_gpus == 2 or step.num_gpus == 4 %}
queue: gpu_4_queue
{% else %}
queue: gpu_1_queue
{% endif %}
soft_fail: true
{% if step.parallelism %}
parallelism: {{ step.parallelism }}
{% endif %}
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
- exit_status: -10 # Agent was lost
limit: 5
plugins:
- docker#v5.2.0:
image: {{ docker_image }}
always-pull: true
propagate-environment: true
{% if not step.no_gpu %}
gpus: all
{% endif %}
{% if step.label == "Benchmarks" %}
mount-buildkite-agent: true
{% endif %}
command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}"]
environment:
- VLLM_USAGE_SOURCE=ci-test
- HF_TOKEN
{% if step.label == "Speculative decoding tests" %}
- VLLM_ATTENTION_BACKEND=XFORMERS
{% endif %}
volumes:
- /dev/shm:/dev/shm
{% endfor %}

View File

@@ -3,7 +3,6 @@
{% set default_working_dir = "/vllm-workspace/tests" %} {% set default_working_dir = "/vllm-workspace/tests" %}
steps: steps:
- label: ":docker: build image" - label: ":docker: build image"
commands: commands:
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
@@ -14,19 +13,36 @@ steps:
automatic: automatic:
- exit_status: -1 # Agent was lost - exit_status: -1 # Agent was lost
limit: 5 limit: 5
- exit_status: -10 # Agent was lost
limit: 5
- wait - wait
- label: "AMD Test" - group: "AMD Tests"
depends_on: ~
steps:
{% for step in steps %}
{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
- label: "AMD: {{ step.label }}"
agents: agents:
queue: amd queue: amd
command: bash .buildkite/run-amd-test.sh command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
env:
DOCKER_BUILDKIT: "1"
soft_fail: true
{% endif %}
{% endfor %}
- label: "Neuron Test" - label: "Neuron Test"
depends_on: ~
agents: agents:
queue: neuron queue: neuron
command: bash .buildkite/run-neuron-test.sh command: bash .buildkite/run-neuron-test.sh
soft_fail: false
- label: "CPU Test" - label: "Intel Test"
depends_on: ~
agents:
queue: intel
command: bash .buildkite/run-cpu-test.sh command: bash .buildkite/run-cpu-test.sh
{% for step in steps %} {% for step in steps %}
@@ -41,9 +57,14 @@ steps:
automatic: automatic:
- exit_status: -1 # Agent was lost - exit_status: -1 # Agent was lost
limit: 5 limit: 5
- exit_status: -10 # Agent was lost
limit: 5
plugins: plugins:
- kubernetes: - kubernetes:
podSpec: podSpec:
{% if step.num_gpus %}
priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
{% endif %}
volumes: volumes:
- name: dshm - name: dshm
emptyDir: emptyDir:

26
.clang-format Normal file
View File

@@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80
# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left
# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false
# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash
IncludeCategories:
- Regex: '^<'
Priority: 4
- Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
Priority: 3
- Regex: '^"(qoda|\.\.)/'
Priority: 2
- Regex: '.*'
Priority: 1

View File

@@ -59,6 +59,8 @@ body:
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues.
If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
placeholder: | placeholder: |
A clear and concise description of what the bug is. A clear and concise description of what the bug is.

49
.github/ISSUE_TEMPLATE/750-RFC.yml vendored Normal file
View File

@@ -0,0 +1,49 @@
name: 💬 Request for comments (RFC).
description: Ask for feedback on major architectural changes or design choices.
title: "[RFC]: "
labels: ["RFC"]
body:
- type: markdown
attributes:
value: >
#### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference.
- type: textarea
attributes:
label: Motivation.
description: >
The motivation of the RFC.
validations:
required: true
- type: textarea
attributes:
label: Proposed Change.
description: >
The proposed change of the RFC.
validations:
required: true
- type: textarea
attributes:
label: Feedback Period.
description: >
The feedback period of the RFC. Usually at least one week.
validations:
required: false
- type: textarea
attributes:
label: CC List.
description: >
The list of people you want to CC.
validations:
required: false
- type: textarea
attributes:
label: Any Other Things.
description: >
Any other things you would like to mention.
validations:
required: false
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

42
.github/workflows/clang-format.yml vendored Normal file
View File

@@ -0,0 +1,42 @@
name: clang-format
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
clang-format:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install clang-format==18.1.5
- name: Running clang-format
run: |
EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
| xargs clang-format --dry-run --Werror

View File

@@ -33,19 +33,19 @@ jobs:
- name: Mypy - name: Mypy
run: | run: |
mypy vllm/attention --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir mypy vllm/core --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
# TODO(sang): Fix nested dir mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml
# TODO(sang): Fix nested dir mypy vllm/logging --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml

View File

@@ -49,7 +49,7 @@ jobs:
matrix: matrix:
os: ['ubuntu-20.04'] os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11'] python-version: ['3.8', '3.9', '3.10', '3.11']
pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt. pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1'] cuda-version: ['11.8', '12.1']
steps: steps:
@@ -58,6 +58,9 @@ jobs:
- name: Setup ccache - name: Setup ccache
uses: hendrikmuhs/ccache-action@v1.2 uses: hendrikmuhs/ccache-action@v1.2
with:
create-symlink: true
key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
- name: Set up Linux Env - name: Set up Linux Env
if: ${{ runner.os == 'Linux' }} if: ${{ runner.os == 'Linux' }}
@@ -79,6 +82,8 @@ jobs:
- name: Build wheel - name: Build wheel
shell: bash shell: bash
env:
CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
run: | run: |
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename) wheel_name=$(ls dist/*whl | xargs -n 1 basename)

View File

@@ -8,7 +8,7 @@ module.exports = async (github, context, core) => {
generate_release_notes: true, generate_release_notes: true,
name: process.env.RELEASE_TAG, name: process.env.RELEASE_TAG,
owner: context.repo.owner, owner: context.repo.owner,
prerelease: false, prerelease: true,
repo: context.repo.repo, repo: context.repo.repo,
tag_name: process.env.RELEASE_TAG, tag_name: process.env.RELEASE_TAG,
}); });

View File

@@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch # requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm # versions are derived from Dockerfile.rocm
# #
set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1") set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
@@ -66,19 +66,6 @@ endif()
# #
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
#
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
# `libtorch_python.so` for linking against an extension. Torch's cmake
# configuration does not include this library (presumably since the cmake
# config is used for standalone C++ binaries that link against torch).
# The `libtorch_python.so` library defines some of the glue code between
# torch/python via pybind and is required by VLLM extensions for this
# reason. So, add it by manually with `find_library` using torch's
# installed library path.
#
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")
# #
# Forward the non-CUDA device extensions to external CMake scripts. # Forward the non-CUDA device extensions to external CMake scripts.
# #
@@ -167,17 +154,47 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/fp8/fp8_cuda_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu" "csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp") "csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# CUTLASS 3.5.0
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
)
FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/custom_all_reduce.cu") "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()
endif() endif()
define_gpu_extension_target( define_gpu_extension_target(
@@ -187,6 +204,8 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI) WITH_SOABI)
# #
@@ -194,7 +213,7 @@ define_gpu_extension_target(
# #
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu")
define_gpu_extension_target( define_gpu_extension_target(
@@ -204,6 +223,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC} SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
USE_SABI 3
WITH_SOABI) WITH_SOABI)
# #
@@ -217,7 +237,8 @@ set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cc") "csrc/punica/punica_ops.cu"
"csrc/punica/torch_bindings.cpp")
# #
# Copy GPU compilation flags+update for punica # Copy GPU compilation flags+update for punica
@@ -241,6 +262,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
endif() endif()
endforeach() endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif() endif()
if (VLLM_PUNICA_GPU_ARCHES) if (VLLM_PUNICA_GPU_ARCHES)
@@ -251,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
SOURCES ${VLLM_PUNICA_EXT_SRC} SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI) WITH_SOABI)
else() else()
message(WARNING "Unable to create _punica_C target because none of the " message(WARNING "Unable to create _punica_C target because none of the "
@@ -275,9 +300,7 @@ add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
add_dependencies(default _C) add_dependencies(default _C)
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling moe extension.") message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C) add_dependencies(default _moe_C)

View File

@@ -1,9 +1,13 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used # The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server. # to run the OpenAI compatible server.
# Please update any changes made here to
# docs/source/dev/dockerfile/dockerfile.rst and
# docs/source/assets/dev/dockerfile-stages-dependency.png
#################### BASE BUILD IMAGE #################### #################### BASE BUILD IMAGE ####################
# prepare basic build environment # prepare basic build environment
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y python3-pip git && apt-get install -y python3-pip git
@@ -12,7 +16,7 @@ RUN apt-get update -y \
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully # https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image # this won't be needed for future versions of this docker image
# or future versions of triton. # or future versions of triton.
RUN ldconfig /usr/local/cuda-12.1/compat/ RUN ldconfig /usr/local/cuda-12.4/compat/
WORKDIR /workspace WORKDIR /workspace
@@ -71,34 +75,15 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist python3 setup.py bdist_wheel --dist-dir=dist
# the `vllm_nccl` package must be installed from source distribution # check the size of the wheel, we cannot upload wheels larger than 100MB
# pip is too smart to store a wheel in the cache, and other CI jobs COPY .buildkite/check-wheel-size.py check-wheel-size.py
# will directly use the wheel from the cache, which is not what we want. RUN python3 check-wheel-size.py dist
# we need to remove it manually
RUN --mount=type=cache,target=/root/.cache/pip \
pip cache remove vllm_nccl*
#################### EXTENSION Build IMAGE #################### #################### EXTENSION Build IMAGE ####################
#################### FLASH_ATTENTION Build IMAGE ####################
FROM dev as flash-attn-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# flash attention version
ARG flash_attn_version=v2.5.6
ENV FLASH_ATTN_VERSION=${flash_attn_version}
WORKDIR /usr/src/flash-attention-v2
# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
--no-build-isolation --no-deps --no-cache-dir
#################### FLASH_ATTENTION Build IMAGE ####################
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################
# image with vLLM installed # image with vLLM installed
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
WORKDIR /vllm-workspace WORKDIR /vllm-workspace
RUN apt-get update -y \ RUN apt-get update -y \
@@ -108,16 +93,12 @@ RUN apt-get update -y \
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully # https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image # this won't be needed for future versions of this docker image
# or future versions of triton. # or future versions of triton.
RUN ldconfig /usr/local/cuda-12.1/compat/ RUN ldconfig /usr/local/cuda-12.4/compat/
# install vllm wheel first, so that torch etc will be installed # install vllm wheel first, so that torch etc will be installed
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/pip \
pip install dist/*.whl --verbose pip install dist/*.whl --verbose
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
--mount=type=cache,target=/root/.cache/pip \
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################

View File

@@ -1,13 +1,15 @@
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
FROM ubuntu:22.04 FROM ubuntu:22.04 AS cpu-test-1
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
RUN pip install --upgrade pip \ RUN pip install --upgrade pip \
&& pip install wheel packaging ninja setuptools>=49.4.0 numpy && pip install wheel packaging ninja "setuptools>=49.4.0" numpy
FROM cpu-test-1 AS build
COPY ./ /workspace/vllm COPY ./ /workspace/vllm
@@ -17,4 +19,8 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
WORKDIR /workspace/
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
CMD ["/bin/bash"] CMD ["/bin/bash"]

View File

@@ -28,7 +28,7 @@ COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
RUN cd /app/vllm \ RUN cd /app/vllm \
&& python3 -m pip install -U -r requirements-neuron.txt && python3 -m pip install -U -r requirements-neuron.txt
ENV VLLM_BUILD_WITH_NEURON 1 ENV VLLM_TARGET_DEVICE neuron
RUN cd /app/vllm \ RUN cd /app/vllm \
&& pip install -e . \ && pip install -e . \
&& cd .. && cd ..

View File

@@ -46,7 +46,7 @@ RUN apt-get update && apt-get install -y \
### Mount Point ### ### Mount Point ###
# When launching the container, mount the code directory to /app # When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app ARG APP_MOUNT=/vllm-workspace
VOLUME [ ${APP_MOUNT} ] VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT} WORKDIR ${APP_MOUNT}
@@ -89,18 +89,27 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
&& cd ../..; \ && cd ../..; \
fi fi
COPY ./ /app/vllm WORKDIR /vllm-workspace
COPY . .
#RUN python3 -m pip install pynvml # to be removed eventually
RUN python3 -m pip install --upgrade pip numba RUN python3 -m pip install --upgrade pip numba
RUN cd /app \ # make sure punica kernels are built (for LoRA)
&& cd vllm \ ENV VLLM_INSTALL_PUNICA_KERNELS=1
&& pip install -U -r requirements-rocm.txt \ # Workaround for ray >= 2.10.0
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \ && python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
&& cd .. && cd ..
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
CMD ["/bin/bash"] CMD ["/bin/bash"]

View File

@@ -1,6 +1,9 @@
include LICENSE include LICENSE
include requirements-common.txt include requirements-common.txt
include requirements-cuda.txt include requirements-cuda.txt
include requirements-rocm.txt
include requirements-neuron.txt
include requirements-cpu.txt
include CMakeLists.txt include CMakeLists.txt
recursive-include cmake * recursive-include cmake *

View File

@@ -14,6 +14,24 @@ Easy, fast, and cheap LLM serving for everyone
</p> </p>
---
**Ray Summit CPF is Open (June 4th to June 20th)!**
There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year!
If you have cool projects related to vLLM or LLM inference, we would love to see your proposals.
This will be a great chance for everyone in the community to get together and learn.
Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite)
**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
We are thrilled to announce our fourth vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://lu.ma/agivllm) and join us!
---
*Latest News* 🔥 *Latest News* 🔥
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
@@ -51,40 +69,14 @@ vLLM is flexible and easy to use with:
- (Experimental) Prefix caching support - (Experimental) Prefix caching support
- (Experimental) Multi-lora support - (Experimental) Multi-lora support
vLLM seamlessly supports many Hugging Face models, including the following architectures: vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral)
- Multi-modal LLMs (e.g., LLaVA)
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) ## Getting Started
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
@@ -92,9 +84,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
pip install vllm pip install vllm
``` ```
## Getting Started Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
@@ -104,6 +94,33 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
We welcome and value any contributions and collaborations. We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
## Sponsors
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
<!-- Note: Please sort them in alphabetical order. -->
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
- a16z
- AMD
- Anyscale
- AWS
- Crusoe Cloud
- Databricks
- DeepInfra
- Dropbox
- Lambda Lab
- NVIDIA
- Replicate
- Roblox
- RunPod
- Sequoia Capital
- Trainy
- UC Berkeley
- UC San Diego
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
## Citation ## Citation
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):

View File

@@ -89,6 +89,9 @@ async def async_request_tgi(
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
output.success = True output.success = True
output.generated_text = data["generated_text"] output.generated_text = data["generated_text"]
else:
output.error = response.reason or ""
output.success = False
except Exception: except Exception:
output.success = False output.success = False
exc_info = sys.exc_info() exc_info = sys.exc_info()
@@ -276,6 +279,9 @@ async def async_request_openai_completions(
output.generated_text = generated_text output.generated_text = generated_text
output.success = True output.success = True
output.latency = latency output.latency = latency
else:
output.error = response.reason or ""
output.success = False
except Exception: except Exception:
output.success = False output.success = False
exc_info = sys.exc_info() exc_info = sys.exc_info()

View File

@@ -1,14 +1,16 @@
"""Benchmark the latency of processing a single batch of requests.""" """Benchmark the latency of processing a single batch of requests."""
import argparse import argparse
import json
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -18,6 +20,8 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch, # NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model, llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer, tokenizer=args.tokenizer,
quantization=args.quantization, quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size, tensor_parallel_size=args.tensor_parallel_size,
@@ -28,9 +32,12 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path, quantization_param_path=args.quantization_param_path,
device=args.device, device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight, ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill, enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir, download_dir=args.download_dir,
block_size=args.block_size) block_size=args.block_size,
gpu_memory_utilization=args.gpu_memory_utilization,
distributed_executor_backend=args.distributed_executor_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=args.n, n=args.n,
@@ -44,7 +51,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size, size=(args.batch_size,
args.input_len)) args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
def run_to_completion(profile_dir: Optional[str] = None): def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir: if profile_dir:
@@ -55,13 +64,13 @@ def main(args: argparse.Namespace):
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p: str(profile_dir))) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids, llm.generate(dummy_inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
print(p.key_averages()) print(p.key_averages())
else: else:
start_time = time.perf_counter() start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids, llm.generate(dummy_inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
end_time = time.perf_counter() end_time = time.perf_counter()
@@ -93,12 +102,24 @@ def main(args: argparse.Namespace):
for percentage, percentile in zip(percentages, percentiles): for percentage, percentile in zip(percentages, percentiles):
print(f'{percentage}% percentile latency: {percentile} seconds') print(f'{percentage}% percentile latency: {percentile} seconds')
# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of ' description='Benchmark the latency of processing a single batch of '
'requests till completion.') 'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m') parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
@@ -137,15 +158,13 @@ if __name__ == '__main__':
action='store_true', action='store_true',
help='enforce eager mode and disable CUDA graph') help='enforce eager mode and disable CUDA graph')
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", '--kv-cache-dtype',
type=str, type=str,
choices=['auto', 'fp8'], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default='auto', default="auto",
help= help='Data type for kv cache storage. If "auto", will use model '
'Data type for kv cache storage. If "auto", will use model data type. ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=str, type=str,
@@ -181,6 +200,7 @@ if __name__ == '__main__':
action='store_true', action='store_true',
help='If True, the prefill requests can be chunked based on the ' help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument( parser.add_argument(
"--ray-workers-use-nsight", "--ray-workers-use-nsight",
action='store_true', action='store_true',
@@ -191,5 +211,23 @@ if __name__ == '__main__':
default=None, default=None,
help='directory to download and load the weights, ' help='directory to download and load the weights, '
'default to the default cache dir of huggingface') 'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the latency results in JSON format.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
def main(args): def main(args):
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", llm = LLM(model=args.model,
tokenizer_mode='auto', tokenizer_mode='auto',
trust_remote_code=True, trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
use_v2_block_manager=args.use_v2_block_manager,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching) enable_prefix_caching=args.enable_prefix_caching)
num_prompts = 100 num_prompts = 100
prompts = [PROMPT] * num_prompts prompts = [PROMPT] * num_prompts
sampling_params = SamplingParams(temperature=0, max_tokens=100) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
print("------warm up------") print("------warm up------")
test_prefix( test_prefix(
llm=llm, llm=llm,
prompts=prompts[:1], prompts=prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
) )
@@ -45,8 +47,16 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Benchmark the performance with or without automatic ' description='Benchmark the performance with or without automatic '
'prefix caching.') 'prefix caching.')
parser.add_argument('--model',
type=str,
default='baichuan-inc/Baichuan2-13B-Chat')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching', parser.add_argument('--enable-prefix-caching',
action='store_true', action='store_true',
help='enable prefix caching') help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@@ -17,6 +17,10 @@ On the client side, run:
--dataset-path <path to dataset> \ --dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf --request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000 --num-prompts <num_prompts> # By default <num_prompts> is 1000
when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
""" """
import argparse import argparse
import asyncio import asyncio
@@ -27,7 +31,7 @@ import time
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, List, Tuple from typing import AsyncGenerator, List, Optional, Tuple
import numpy as np import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@@ -52,13 +56,20 @@ class BenchmarkMetrics:
mean_tpot_ms: float mean_tpot_ms: float
median_tpot_ms: float median_tpot_ms: float
p99_tpot_ms: float p99_tpot_ms: float
mean_itl_ms: float
median_itl_ms: float
p99_itl_ms: float
def sample_sharegpt_requests( def sample_sharegpt_requests(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
@@ -68,38 +79,32 @@ def sample_sharegpt_requests(
dataset = [(data["conversations"][0]["value"], dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset] data["conversations"][1]["value"]) for data in dataset]
# some of these will be filtered out, so sample more than we need # Shuffle the dataset.
sampled_indices = random.sample(range(len(dataset)), random.shuffle(dataset)
int(num_requests * 1.2))
dataset = [dataset[i] for i in sampled_indices] # Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions. # Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset] prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompts).input_ids prompt_token_ids = tokenizer(prompt).input_ids
completions = [completion for _, completion in dataset] completion = dataset[i][1]
completion_token_ids = tokenizer(completions).input_ids completion_token_ids = tokenizer(completion).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4: if prompt_len < 4 or output_len < 4:
# Prune too short sequences. # Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue continue
if prompt_len > 1024 or prompt_len + output_len > 2048: if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences. # Prune too long sequences.
continue continue
filtered_dataset.append((prompt, prompt_len, output_len)) filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests. return filtered_dataset
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
def sample_sonnet_requests( def sample_sonnet_requests(
@@ -198,21 +203,34 @@ def calculate_metrics(
actual_output_lens = [] actual_output_lens = []
total_input = 0 total_input = 0
completed = 0 completed = 0
itls = []
tpots = [] tpots = []
ttfts = [] ttfts = []
for i in range(len(outputs)): for i in range(len(outputs)):
if outputs[i].success: if outputs[i].success:
output_len = len(tokenizer(outputs[i].generated_text).input_ids) # We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
# Note: this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
actual_output_lens.append(output_len) actual_output_lens.append(output_len)
total_input += input_requests[i][1] total_input += input_requests[i][1]
if output_len > 1: if output_len > 1:
tpots.append( tpots.append(
(outputs[i].latency - outputs[i].ttft) / (output_len - 1)) (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
itls += outputs[i].itl
ttfts.append(outputs[i].ttft) ttfts.append(outputs[i].ttft)
completed += 1 completed += 1
else: else:
actual_output_lens.append(0) actual_output_lens.append(0)
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
@@ -224,9 +242,12 @@ def calculate_metrics(
1000, # ttfts is empty if streaming is not supported by backend 1000, # ttfts is empty if streaming is not supported by backend
median_ttft_ms=np.median(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
mean_tpot_ms=np.mean(tpots) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000,
p99_tpot_ms=np.percentile(tpots, 99) * 1000, p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
) )
return metrics, actual_output_lens return metrics, actual_output_lens
@@ -248,6 +269,24 @@ async def benchmark(
else: else:
raise ValueError(f"Unknown backend: {backend}") raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")
print(f"Traffic request rate: {request_rate}") print(f"Traffic request rate: {request_rate}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests)) pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -308,6 +347,10 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
metrics.median_tpot_ms)) metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("=" * 50) print("=" * 50)
result = { result = {
@@ -324,6 +367,9 @@ async def benchmark(
"mean_tpot_ms": metrics.mean_tpot_ms, "mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms, "median_tpot_ms": metrics.median_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms, "p99_tpot_ms": metrics.p99_tpot_ms,
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"input_lens": [output.prompt_len for output in outputs], "input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens, "output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs], "ttfts": [output.ttft for output in outputs],
@@ -361,6 +407,7 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset, dataset_path=args.dataset,
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
) )
elif args.dataset_name == "sharegpt": elif args.dataset_name == "sharegpt":
@@ -368,6 +415,7 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
) )
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
@@ -524,6 +572,12 @@ if __name__ == "__main__":
default=1000, default=1000,
help="Number of prompts to process.", help="Number of prompts to process.",
) )
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
parser.add_argument( parser.add_argument(
"--sonnet-input-len", "--sonnet-input-len",
type=int, type=int,

View File

@@ -78,6 +78,7 @@ def run_vllm(
enable_prefix_caching: bool, enable_prefix_caching: bool,
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
max_num_batched_tokens: int, max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
) -> float: ) -> float:
@@ -100,28 +101,26 @@ def run_vllm(
download_dir=download_dir, download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
) )
# Add the requests to the engine. # Add the requests to the engine.
prompts = []
sampling_params = []
for prompt, _, output_len in requests: for prompt, _, output_len in requests:
sampling_params = SamplingParams( prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=0.0 if use_beam_search else 1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search, use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
) ))
# FIXME(woosuk): Do not use internal method.
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=sampling_params,
)
start = time.perf_counter() start = time.perf_counter()
# FIXME(woosuk): Do not use internal method. llm.generate(prompts, sampling_params, use_tqdm=True)
llm._run_engine(use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
return end - start return end - start
@@ -228,8 +227,8 @@ def main(args: argparse.Namespace):
args.enforce_eager, args.kv_cache_dtype, args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device, args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill, args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.gpu_memory_utilization, args.max_num_batched_tokens, args.distributed_executor_backend,
args.download_dir) args.gpu_memory_utilization, args.download_dir)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -245,6 +244,18 @@ def main(args: argparse.Namespace):
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s") f"{total_num_tokens / elapsed_time:.2f} tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.") parser = argparse.ArgumentParser(description="Benchmark the throughput.")
@@ -314,15 +325,13 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="enforce eager execution") help="enforce eager execution")
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", '--kv-cache-dtype',
type=str, type=str,
choices=["auto", "fp8"], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto", default="auto",
help= help='Data type for kv cache storage. If "auto", will use model '
'Data type for kv cache storage. If "auto", will use model data type. ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=str, type=str,
@@ -356,6 +365,18 @@ if __name__ == "__main__":
default=None, default=None,
help='directory to download and load the weights, ' help='directory to download and load the weights, '
'default to the default cache dir of huggingface') 'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
args = parser.parse_args() args = parser.parse_args()
if args.tokenizer is None: if args.tokenizer is None:
args.tokenizer = args.model args.tokenizer = args.model

View File

@@ -0,0 +1,352 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
# helpers
def to_fp8(tensor: torch.tensor) -> torch.tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.tensor) -> torch.tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.tensor, torch.tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
if dtype == torch.int8:
return to_int8(a), to_int8(b)
if dtype == torch.float8_e4m3fn:
return to_fp8(a), to_fp8(b)
raise ValueError("unsupported dtype")
# impl
def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return torch.mm(a, b)
def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype)
def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
scale_a: torch.tensor, scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
use_fast_accum=True)
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return ops.cutlass_scaled_mm_dq(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype)
# bench
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
sub_label: str, fn: Callable, description: str) -> TMeasurement:
min_run_time = 1
globals = {
"a": a,
"b": b,
"scale_a": scale_a,
"scale_b": scale_b,
"out_dtype": out_dtype,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
timers = []
# pytorch impl
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_i8_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
# cutlass impl
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.bfloat16, label, sub_label, cutlass_impl,
"cutlass_i8_i8_bf16_scaled_mm"))
return timers
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
timers = []
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
# cutlass impl: bf16 output
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.bfloat16, label, sub_label, cutlass_impl,
"cutlass_fp8_fp8_bf16_scaled_mm"))
# cutlass impl: fp16 output
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.float16, label, sub_label, cutlass_impl,
"cutlass_fp8_fp8_fp16_scaled_mm"))
return timers
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
raise ValueError("unsupported type")
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})")
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == '__main__':
def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")
parser = argparse.ArgumentParser(
description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']")
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)

View File

@@ -0,0 +1,37 @@
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}

View File

@@ -6,7 +6,7 @@ from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import ( from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype, dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm) optimized_dequantize_gemm)

View File

@@ -0,0 +1,233 @@
import argparse
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
size_m, size_k, size_n):
label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
group_size, size_m, size_k, size_n))
print(f"Testing: {sub_label}")
a = torch.randn(size_m, size_k).to(torch.half).cuda()
b = torch.rand(size_k, size_n).to(torch.half).cuda()
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
# Marlin quant
(
marlin_w_ref,
marlin_q_w,
marlin_s,
marlin_g_idx,
marlin_sort_indices,
marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order)
# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx"
# so that group ids are increasing
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
if act_order:
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
# Prepare
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
globals = {
# Gen params
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"a": a,
"a_tmp": a_tmp,
# Marlin params
"marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w,
"marlin_s": marlin_s,
"marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm,
"marlin_workspace": marlin_workspace,
"is_k_full": is_k_full,
# Marlin_24 params
"marlin_24_w_ref": marlin_24_w_ref,
"marlin_24_q_w_comp": marlin_24_q_w_comp,
"marlin_24_meta": marlin_24_meta,
"marlin_24_s": marlin_24_s,
"marlin_24_workspace": marlin_24_workspace,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
}
min_run_time = 1
# Warmup pytorch
for i in range(5):
torch.matmul(a, marlin_w_ref)
results.append(
benchmark.Timer(
stmt="torch.matmul(a, marlin_w_ref)",
globals=globals,
label=label,
sub_label=sub_label,
description="pytorch_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm",
).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))
def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
results = []
for model in args.models:
for layer in WEIGHT_SHAPES[model]:
size_k = layer[0]
size_n = layer[1]
if len(args.limit_k) > 0 and size_k not in args.limit_k:
continue
if len(args.limit_n) > 0 and size_n not in args.limit_n:
continue
for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order
) > 0 and act_order not in args.limit_act_order:
continue
for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full
) > 0 and is_k_full not in args.limit_k_full:
continue
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
continue
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
) > 0 and group_size not in args.limit_group_size:
continue
# For act_order, the group_size must be less than
# size_k
if act_order and (group_size == size_k
or group_size == -1):
continue
for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full,
num_bits, group_size, size_m, size_k,
size_n)
compare = benchmark.Compare(results)
compare.print()
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
args = parser.parse_args()
main(args)

View File

@@ -1,182 +0,0 @@
import json
import os
import sys
import torch
import torch.nn.functional as F
import triton
from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def main():
method = fused_moe
for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
run_grid(bs, method=method)
def run_grid(bs, method):
d_model = 4096
num_total_experts = 8
top_k = 2
tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100
num_warmup_trials = 1
num_trials = 1
configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]
for block_size_n in [32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M:
for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": 4,
})
best_config = None
best_time_us = 1e20
for config in configs:
print(f'{tp_size=} {bs=}')
print(f'{config}')
# warmup
print('warming up')
try:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)
except triton.runtime.autotuner.OutOfResources:
continue
# trial
print('benchmarking')
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)
kernel_dur_us = 1000 * kernel_dur_ms
model_dur_ms = kernel_dur_ms * num_layers
if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us
print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}')
print("best_time_us", best_time_us)
print("best_config", best_config)
# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(num_total_experts,
model_intermediate_size // tp_size)
print(f"writing config to file {filename}")
existing_content = {}
if os.path.exists(filename):
with open(filename, "r") as f:
existing_content = json.load(f)
existing_content[str(bs)] = best_config
with open(filename, "w") as f:
json.dump(existing_content, f, indent=4)
f.write("\n")
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method,
config) -> float:
shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=torch.bfloat16,
)
ws = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w2s = torch.rand(
(num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
),
dim=-1)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=ws,
w2=w2s,
gating_output=gating_output[i],
topk=2,
renormalize=True,
inplace=True,
override_config=config,
)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,322 @@
import argparse
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple
import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
def benchmark_config(
config: Dict[str, int],
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_fp8:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
def prepare(i: int):
input_gating.copy_(gating_output[i])
def run():
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
override_config=config,
use_fp8=use_fp8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs = []
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append({
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
})
return configs
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(seed)
self.seed = seed
def benchmark(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(self.seed)
dtype_str = "float8" if use_fp8 else None
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
dtype_str)
if op_config is None:
config = get_default_config(num_tokens, num_experts,
shard_intermediate_size, hidden_size,
topk, dtype_str)
else:
config = op_config[min(op_config.keys(),
key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts,
shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
return config, kernel_time
def tune(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8,
num_iters=10)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
return best_config
def sort_config(config: Dict[str, int]) -> Dict[str, int]:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
}
def save_configs(
configs: Dict[int, Dict[str, int]],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
) -> None:
dtype_str = "float8" if use_fp8 else None
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(args.model)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8 = args.dtype == "fp8"
if args.batch_size is None:
batch_sizes = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]
else:
batch_sizes = [args.batch_size]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
if args.tune:
search_space = get_configs_compute_bound()
print(f"Start tuning over {len(search_space)} configurations...")
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8, search_space)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute("benchmark",
[(batch_size, E, shard_intermediate_size,
hidden_size, topk, dtype, use_fp8)
for batch_size in batch_sizes])
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8"],
default="auto")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
args = parser.parse_args()
main(args)

View File

@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def main( def main(
version: str, version: str,
num_seqs: int, num_seqs: int,
context_len: int, seq_len: int,
num_query_heads: int, num_query_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
@@ -48,12 +48,12 @@ def main(
dtype=torch.float, dtype=torch.float,
device=device) device=device)
context_lens = [context_len for _ in range(num_seqs)] seq_lens = [seq_len for _ in range(num_seqs)]
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = [] block_tables = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel. # Prepare for the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
if version == "v2": if version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) // num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
PARTITION_SIZE)
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size), size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype, dtype=output.dtype,
@@ -110,9 +109,9 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
@@ -129,9 +128,9 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
@@ -166,12 +165,12 @@ if __name__ == '__main__':
choices=["v1", "v2"], choices=["v1", "v2"],
default="v2") default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
@@ -184,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", "--kv-cache-dtype",
type=str, type=str,
choices=["auto", "fp8"], choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto", default="auto",
help= help="Data type for kv cache storage. If 'auto', will use model "
'Data type for kv cache storage. If "auto", will use model data type. ' "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
@@ -199,7 +196,7 @@ if __name__ == '__main__':
main( main(
version=args.version, version=args.version,
num_seqs=args.batch_size, num_seqs=args.batch_size,
context_len=args.context_len, seq_len=args.seq_len,
num_query_heads=args.num_query_heads, num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads, num_kv_heads=args.num_kv_heads,
head_size=args.head_size, head_size=args.head_size,

View File

@@ -93,7 +93,7 @@ if __name__ == '__main__':
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument("--dtype",

View File

@@ -0,0 +1,75 @@
WEIGHT_SHAPES = {
"ideal": [[4 * 256 * 32, 256 * 32]],
"mistralai/Mistral-7B-v0.1/TP1": [
[4096, 6144],
[4096, 4096],
[4096, 28672],
[14336, 4096],
],
"mistralai/Mistral-7B-v0.1/TP2": [
[4096, 3072],
[2048, 4096],
[4096, 14336],
[7168, 4096],
],
"mistralai/Mistral-7B-v0.1/TP4": [
[4096, 1536],
[1024, 4096],
[4096, 7168],
[3584, 4096],
],
"meta-llama/Llama-2-7b-hf/TP1": [
[4096, 12288],
[4096, 4096],
[4096, 22016],
[11008, 4096],
],
"meta-llama/Llama-2-7b-hf/TP2": [
[4096, 6144],
[2048, 4096],
[4096, 11008],
[5504, 4096],
],
"meta-llama/Llama-2-7b-hf/TP4": [
[4096, 3072],
[1024, 4096],
[4096, 5504],
[2752, 4096],
],
"meta-llama/Llama-2-13b-hf/TP1": [
[5120, 15360],
[5120, 5120],
[5120, 27648],
[13824, 5120],
],
"meta-llama/Llama-2-13b-hf/TP2": [
[5120, 7680],
[2560, 5120],
[5120, 13824],
[6912, 5120],
],
"meta-llama/Llama-2-13b-hf/TP4": [
[5120, 3840],
[1280, 5120],
[5120, 6912],
[3456, 5120],
],
"meta-llama/Llama-2-70b-hf/TP1": [
[8192, 10240],
[8192, 8192],
[8192, 57344],
[28672, 8192],
],
"meta-llama/Llama-2-70b-hf/TP2": [
[8192, 5120],
[4096, 8192],
[8192, 28672],
[14336, 8192],
],
"meta-llama/Llama-2-70b-hf/TP4": [
[8192, 2560],
[2048, 8192],
[8192, 14336],
[7168, 8192],
],
}

View File

@@ -4,7 +4,7 @@ PORT=8000
MODEL=$1 MODEL=$1
TOKENS=$2 TOKENS=$2
docker run --gpus all --shm-size 1g -p $PORT:80 \ docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \ -v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:1.4.0 \ ghcr.io/huggingface/text-generation-inference:1.4.0 \
--model-id $MODEL \ --model-id $MODEL \

View File

@@ -0,0 +1,63 @@
import argparse
import cProfile
import pstats
from vllm import LLM, SamplingParams
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)
def main(args):
llm = LLM(
model=args.model,
enforce_eager=True,
enable_prefix_caching=True,
tensor_parallel_size=args.tensor_parallel_size,
use_v2_block_manager=args.use_v2_block_manager,
)
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
profiler = cProfile.Profile()
print("------warm up------")
for i in range(3):
output = llm.generate(LONG_PROMPT, sampling_params)
print(output[0].outputs[0].text)
print("------start generating------")
for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
globals(), locals())
# analyze the runtime of hashing function
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
total_time = 0
total_calls = 0
for func in stats.stats:
if 'hash_of_block' in func[2]:
total_time = stats.stats[func][3]
total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds,"
f"{percentage:.2f}% of the total runtime.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args()
main(args)

View File

@@ -73,7 +73,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/cache.cpp" "csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp" "csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp" "csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pybind.cpp") "csrc/cpu/torch_bindings.cpp")
define_gpu_extension_target( define_gpu_extension_target(
_C _C
@@ -81,10 +81,10 @@ define_gpu_extension_target(
LANGUAGE CXX LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS} COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI WITH_SOABI
) )
add_custom_target(default) add_custom_target(default)
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
add_dependencies(default _C) add_dependencies(default _C)

View File

@@ -5,7 +5,7 @@
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE) file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE}) set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module) find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND) if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif() endif()
@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags") "Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif() endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS list(REMOVE_ITEM GPU_FLAGS
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
"-DENABLE_FP8_E4M3" "-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc") "-fno-gpu-rdc")
@@ -294,6 +294,7 @@ endmacro()
# INCLUDE_DIRECTORIES <dirs> - Extra include directories. # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LIBRARIES <libraries> - Extra link libraries. # LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name. # WITH_SOABI - Generate library with python SOABI suffix name.
# USE_SABI <version> - Use python stable api <version>
# #
# Note: optimization level/debug info is set via cmake build type. # Note: optimization level/debug info is set via cmake build type.
# #
@@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1 cmake_parse_arguments(PARSE_ARGV 1
GPU GPU
"WITH_SOABI" "WITH_SOABI"
"DESTINATION;LANGUAGE" "DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
# Add hipify preprocessing step when building with HIP/ROCm. # Add hipify preprocessing step when building with HIP/ROCm.
@@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
set(GPU_WITH_SOABI) set(GPU_WITH_SOABI)
endif() endif()
Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) if (GPU_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
endif()
if (GPU_LANGUAGE STREQUAL "HIP") if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step. # Make this target dependent on the hipify preprocessor step.

View File

@@ -64,6 +64,7 @@ DEFAULT_CONDA_PATTERNS = {
"triton", "triton",
"optree", "optree",
"nccl", "nccl",
"transformers",
} }
DEFAULT_PIP_PATTERNS = { DEFAULT_PIP_PATTERNS = {
@@ -75,6 +76,7 @@ DEFAULT_PIP_PATTERNS = {
"optree", "optree",
"onnx", "onnx",
"nccl", "nccl",
"transformers",
} }
@@ -601,6 +603,11 @@ Versions of relevant libraries:
{conda_packages} {conda_packages}
""".strip() """.strip()
# both the above code and the following code use `strip()` to
# remove leading/trailing whitespaces, so we need to add a newline
# in between to separate the two sections
env_info_fmt += "\n"
env_info_fmt += """ env_info_fmt += """
ROCM Version: {rocm_version} ROCM Version: {rocm_version}
Neuron SDK Version: {neuron_sdk_version} Neuron SDK Version: {neuron_sdk_version}

View File

@@ -1,5 +1,5 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cmath> #include <cmath>
@@ -63,31 +63,25 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "act_and_mul_kernel", [&] { \
"act_and_mul_kernel", \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
[&] { \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ input.data_ptr<scalar_t>(), d); \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
}); });
void silu_and_mul( void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
@@ -118,14 +112,10 @@ __global__ void activation_kernel(
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
input.scalar_type(), \ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
"activation_kernel", \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
[&] { \ input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
}); });
namespace vllm { namespace vllm {
@@ -140,21 +130,20 @@ __device__ __forceinline__ T gelu_new_kernel(const T& x) {
template <typename T> template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) { __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float)x; const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t); return ((T)0.5) * x * (((T)1.0) + t);
} }
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);

View File

@@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *

View File

@@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
#endif #endif
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else

View File

@@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@@ -130,7 +132,9 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp; } tmp;
#ifndef USE_ROCM #ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else #else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d; uint32_t d;
#ifndef USE_ROCM #ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else #else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif #endif
return d; return d;
} }
@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
} }
// From float16 to float32. // From float16 to float32.
inline __device__ float to_float(uint16_t u) { inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) { inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp; Float4_ tmp;
@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
} }
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm

View File

@@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@@ -66,9 +68,7 @@ struct FloatVec<float4> {
}; };
// Vector addition. // Vector addition.
inline __device__ float add(float a, float b) { inline __device__ float add(float a, float b) { return a + b; }
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) { inline __device__ float2 add(float2 a, float2 b) {
float2 c; float2 c;
@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { inline __device__ float fma(float a, float b, float c) { return a * b + c; }
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) { inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d; float2 d;
@@ -208,9 +206,7 @@ inline __device__ float sum(Float8_ v) {
} }
// Vector dot product. // Vector dot product.
inline __device__ float dot(float a, float b) { inline __device__ float dot(float a, float b) { return a * b; }
return a * b;
}
inline __device__ float dot(float2 a, float2 b) { inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b); float2 c = mul<float2, float2, float2>(a, b);
@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
} }
// From float to float. // From float to float.
inline __device__ void from_float(float& dst, float src) { inline __device__ void from_float(float& dst, float src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
dst = src;
}
// From float to float. // From float to float.
inline __device__ float to_float(float u) { inline __device__ float to_float(float u) { return u; }
return u;
}
inline __device__ float2 to_float(float2 u) { inline __device__ float2 to_float(float2 u) { return u; }
return u;
}
inline __device__ float4 to_float(float4 u) { inline __device__ float4 to_float(float4 u) { return u; }
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { inline __device__ Float4_ to_float(Float4_ u) { return u; }
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { inline __device__ Float8_ to_float(Float8_ u) { return u; }
return u;
}
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(float& dst) { inline __device__ void zero(float& dst) { dst = 0.f; }
dst = 0.f;
}
} // namespace vllm } // namespace vllm

View File

@@ -3,14 +3,21 @@
#include "attention_generic.cuh" #include "attention_generic.cuh"
#include <stdint.h> #include <stdint.h>
#ifdef ENABLE_FP8_E5M2 #ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h> #include <cuda_fp8.h>
#endif #endif // USE_ROCM
#endif // ENABLE_FP8
namespace vllm { namespace vllm {
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};
// fp8 vector types for quantization of kv cache
template <> template <>
struct Vec<uint8_t, 1> { struct Vec<uint8_t, 1> {
using Type = uint8_t; using Type = uint8_t;
@@ -30,6 +37,5 @@ template<>
struct Vec<uint8_t, 8> { struct Vec<uint8_t, 8> {
using Type = uint2; using Type = uint2;
}; };
#endif // ENABLE_FP8_E5M2
} // namespace vllm } // namespace vllm

View File

@@ -1,30 +1,32 @@
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
#include <map> #include <map>
#include <vector> #include <vector>
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping);
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void copy_blocks( // Note: the key_caches and value_caches vectors are constant but
std::vector<torch::Tensor>& key_caches, // not the Tensors they contain. The vectors need to be const refs
std::vector<torch::Tensor>& value_caches, // in order to satisfy pytorch's C++ operator registration code.
const std::map<int64_t, std::vector<int64_t>>& block_mapping); void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping);
void reshape_and_cache( void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& value, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const double kv_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype);
const float kv_scale);
// Just for unittest // Just for unittest
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype);
torch::Tensor& dst_cache);

View File

@@ -1,13 +1,14 @@
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" #ifdef USE_ROCM
#elif defined(ENABLE_FP8_E4M3) #include "quantization/fp8/amd/quant_utils.cuh"
#include "quantization/fp8/amd_detail/quant_utils.cuh" #else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include <algorithm> #include <algorithm>
@@ -20,16 +21,13 @@
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping) {
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
torch::Device src_device = src.device(); torch::Device src_device = src.device();
torch::Device dst_device = dst.device(); torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type; cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) { if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU"); "src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice; memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) { } else if (src_device.is_cuda() && dst_device.is_cpu()) {
@@ -40,24 +38,27 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination"); TORCH_CHECK(false, "Invalid device combination");
} }
// NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr()); char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr()); char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large. // NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) { const int64_t num_blocks = block_mapping.size(0);
int64_t src_block_number = pair.first; for (size_t i = 0; i < num_blocks; i++) {
int64_t dst_block_number = pair.second; int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes; int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync( cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
dst_ptr + dst_offset, block_size_in_bytes, memcpy_type, stream);
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
} }
} }
@@ -65,8 +66,7 @@ namespace vllm {
// Grid: (num_layers, num_pairs) // Grid: (num_layers, num_pairs)
template <typename scalar_t> template <typename scalar_t>
__global__ void copy_blocks_kernel( __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs, int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping, const int64_t* __restrict__ block_mapping,
const int numel_per_block) { const int numel_per_block) {
@@ -74,7 +74,8 @@ __global__ void copy_blocks_kernel(
const int pair_idx = blockIdx.y; const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
@@ -94,10 +95,12 @@ __global__ void copy_blocks_kernel(
} // namespace vllm } // namespace vllm
void copy_blocks( // Note: the key_caches and value_caches vectors are constant but
std::vector<torch::Tensor>& key_caches, // not the Tensors they contain. The vectors need to be const refs
std::vector<torch::Tensor>& value_caches, // in order to satisfy pytorch's C++ operator registration code.
const std::map<int64_t, std::vector<int64_t>>& block_mapping) { void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
@@ -111,29 +114,23 @@ void copy_blocks(
int64_t key_cache_ptrs[num_layers]; int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); key_cache_ptrs[layer_idx] =
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
} }
// Create block mapping array.
std::vector<int64_t> block_mapping_vec; // block_mapping is a 2D tensor with shape (num_pairs, 2).
for (const auto& pair : block_mapping) { int num_pairs = block_mapping.size(0);
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
// Move the data structures to the GPU. // Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU. // NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob( torch::Tensor key_cache_ptrs_tensor =
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( .to(cache_device);
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor =
torch::Tensor block_mapping_tensor = torch::from_blob( torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); .to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
@@ -146,26 +143,23 @@ void copy_blocks(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(), block_mapping.data_ptr<int64_t>(), numel_per_block);
numel_per_block);
})); }));
} }
namespace vllm { namespace vllm {
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] // block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int key_stride, const int value_stride, const int num_heads,
const int value_stride, const int head_size, const int block_size, const int x,
const int num_heads,
const int head_size,
const int block_size,
const int x,
const float kv_scale) { const float kv_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
@@ -187,60 +181,84 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x; const int x_idx = head_offset / x;
const int x_offset = head_offset % x; const int x_offset = head_offset % x;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x const int64_t tgt_key_idx =
+ head_idx * (head_size / x) * block_size * x block_idx * num_heads * (head_size / x) * block_size * x +
+ x_idx * block_size * x head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+ block_offset * x block_offset * x + x_offset;
+ x_offset; const int64_t tgt_value_idx =
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size block_idx * num_heads * head_size * block_size +
+ head_idx * head_size * block_size head_idx * head_size * block_size + head_offset * block_size +
+ head_offset * block_size block_offset;
+ block_offset;
scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_kv_cache) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
} else {
key_cache[tgt_key_idx] = tgt_key; key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
} }
} }
} }
template <typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
// head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
// head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size +
head_idx * head_size + head_offset;
k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx];
}
}
} // namespace vllm } // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ // KV_T is the stored data type of kv-cache.
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \ // CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \ reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \ reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
key_stride, \ num_heads, head_size, block_size, x, kv_scale);
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, const double kv_scale) {
const float kv_scale)
{
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
@@ -254,65 +272,77 @@ void reshape_and_cache(
dim3 block(std::min(num_heads * head_size, 512)); dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
if (key.dtype() == at::ScalarType::Float) { DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE(float, float, false); CALL_RESHAPE_AND_CACHE)
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
} }
} else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) { void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE(float, uint8_t, true); torch::Tensor& key, // [num_tokens, num_heads, head_size]
} else if (key.dtype() == at::ScalarType::Half) { torch::Tensor& value, // [num_tokens, num_heads, head_size]
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
} else if (key.dtype() == at::ScalarType::BFloat16) { torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); torch::Tensor& slot_mapping, // [num_tokens]
} const std::string& kv_cache_dtype) {
} else { // FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
} }
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = k_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = k_cache.stride(0);
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_flash", [&] {
vllm::reshape_and_cache_flash_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
value_stride, num_heads, head_size, block_size);
});
} }
namespace vllm { namespace vllm {
template<typename Tout, typename Tin> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel( __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache, Tout* __restrict__ dst_cache,
const float kv_scale,
const int64_t block_stride) { const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
#if defined(ENABLE_FP8_E5M2) dst_cache[idx] =
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]); fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
#elif defined(ENABLE_FP8_E4M3)
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
} }
} }
} // namespace vllm } // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \ reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
block_stride);
void convert_fp8( // Only for testing.
torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache) const double kv_scale, const std::string& kv_cache_dtype) {
{
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU"); "src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device); at::cuda::OptionalCUDAGuard device_guard(src_device);
@@ -323,17 +353,37 @@ void convert_fp8(
dim3 block(std::min(block_stride, int64_t(512))); dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
if (src_cache.dtype() == at::ScalarType::Float) { if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float); CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::Half) { } else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t); CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) { } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Float) { } else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t); CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Half) { } else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t); CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
} }

View File

@@ -81,12 +81,10 @@ void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl) CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(num_tokens, d, activation_kernel<scalar_t, silu_act, true>(
input.data_ptr<scalar_t>(), num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
}); });
} }
@@ -97,12 +95,10 @@ void gelu_and_mul(torch::Tensor &out, // [..., d]
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, activation_kernel<scalar_t, gelu_act, true>(
input.data_ptr<scalar_t>(), num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
}); });
} }

View File

@@ -2,7 +2,8 @@
namespace { namespace {
template <typename scalar_t> struct KernelVecType { template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void; using q_load_vec_type = void;
using q_vec_type = void; using q_vec_type = void;
using k_load_vec_type = void; using k_load_vec_type = void;
@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
using v_load_vec_type = void; using v_load_vec_type = void;
}; };
template <> struct KernelVecType<float> { template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4; using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16;
@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
}; };
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32; using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32;
@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16; using v_load_vec_type = vec_op::BF16Vec16;
}; };
#else #else
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16; using k_load_vec_type = vec_op::BF16Vec16;
@@ -67,14 +71,15 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
} }
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
reduceSoftmaxAlibi(T *data, const int size, const int capacity, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope,
const int context_len) { const int start_index,
data[0] += alibi_slope * (start_index - context_len + 1); const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk; data[i] = qk;
max = max >= qk ? max : qk; max = max >= qk ? max : qk;
} }
@@ -215,17 +220,17 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
namespace { namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl { struct paged_attention_v1_impl {
static void static void call(
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -235,32 +240,31 @@ struct paged_attention_v1_impl {
static_assert(BLOCK_SIZE == 16); static_assert(BLOCK_SIZE == 16);
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads(); const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes = size_t logits_bytes =
parallel_work_item_num * max_context_len_padded * sizeof(float); parallel_work_item_num * max_seq_len_padded * sizeof(float);
float* logits = (float*)std::aligned_alloc( float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token. 64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded] // [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1) #pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int context_len = context_lens[seq_idx]; int seq_len = seq_lens[seq_idx];
const int* seq_block_table = const int* seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx; block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t* __restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
context_len - (block_num - 1) * BLOCK_SIZE;
float* __restrict__ thread_block_logits = float* __restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_context_len_padded; logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
@@ -278,12 +282,11 @@ struct paged_attention_v1_impl {
// Compute softmax // Compute softmax
if (alibi_slopes) { if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, context_len, reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
context_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, context_len, reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
block_num * BLOCK_SIZE);
} }
// Compute value // Compute value
@@ -340,7 +343,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \ paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads); num_heads);
@@ -348,8 +351,8 @@ template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@@ -369,7 +372,7 @@ void paged_attention_v1_impl_launcher(
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
@@ -387,6 +390,9 @@ void paged_attention_v1_impl_launcher(
case 128: case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 192:
LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break;
case 256: case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
@@ -399,7 +405,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes); seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
@@ -412,15 +418,18 @@ void paged_attention_v1_impl_launcher(
} }
} // namespace } // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, void paged_attention_v1(
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
int num_kv_heads, float scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
torch::Tensor &context_lens, int block_size, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int max_context_len, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const c10::optional<torch::Tensor> &alibi_slopes, const int64_t blocksparse_local_blocks,
const std::string &kv_cache_dtype, float kv_scale) { const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] { [&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -435,9 +444,10 @@ template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl { struct paged_attention_v2_impl {
static void call( static void call(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads,
float // max_num_partitions]
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@@ -446,9 +456,9 @@ struct paged_attention_v2_impl {
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -465,22 +475,20 @@ struct paged_attention_v2_impl {
for (int partition_idx = 0; partition_idx < max_num_partitions; for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) { ++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= context_len) if (start_token_idx >= seq_len) continue;
continue;
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1); const bool no_reduce = (partition_num == 1);
const int context_token_num = const int token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
context_token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int* seq_block_table = block_tables + const int* seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx + max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE; start_token_idx / BLOCK_SIZE;
@@ -507,11 +515,11 @@ struct paged_attention_v2_impl {
std::pair<float, float> max_and_sum; std::pair<float, float> max_and_sum;
if (alibi_slopes) { if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi( max_and_sum = reduceSoftmaxAlibi(
logits, context_token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, context_token_num, max_and_sum =
block_num * BLOCK_SIZE); reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
} }
auto&& [max_logit, exp_sum] = max_and_sum; auto&& [max_logit, exp_sum] = max_and_sum;
@@ -583,12 +591,11 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1) #pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
reducePartitonSoftmax( reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions + max_logits + seq_idx * num_heads * max_num_partitions +
@@ -603,8 +610,8 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group = constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE 16; // Note: didn't align with the cacheline size, due to some
// didn't align with 64 bytes // HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0); static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float* __restrict__ rescale_factors = exp_sums; const float* __restrict__ rescale_factors = exp_sums;
@@ -612,12 +619,11 @@ struct paged_attention_v2_impl {
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
const float* __restrict__ seq_head_rescale_factors = const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions + rescale_factors + seq_idx * num_heads * max_num_partitions +
@@ -649,7 +655,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \ paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions); max_num_partitions);
@@ -658,8 +664,8 @@ void paged_attention_v2_impl_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@@ -683,7 +689,7 @@ void paged_attention_v2_impl_launcher(
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
@@ -701,6 +707,9 @@ void paged_attention_v2_impl_launcher(
case 128: case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 192:
LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break;
case 256: case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
@@ -713,8 +722,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
max_context_len, alibi_slopes); alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
@@ -727,16 +736,19 @@ void paged_attention_v2_impl_launcher(
} }
} // namespace } // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, void paged_attention_v2(
torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
float scale, torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
torch::Tensor &context_lens, int block_size, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int max_context_len, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const c10::optional<torch::Tensor> &alibi_slopes, const int64_t blocksparse_local_blocks,
const std::string &kv_cache_dtype, float kv_scale) { const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] { [&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

View File

@@ -5,19 +5,20 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl( void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> &key_caches, std::vector<torch::Tensor> const& value_caches,
std::vector<torch::Tensor> &value_caches, const torch::Tensor& mapping_pairs,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs, const int element_num_per_block,
const int element_num_per_block, const int layer_num) { const int layer_num) {
const size_t pair_num = mapping_pairs.size(); const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) { for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) { for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset = int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second; element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t* source_ptr = key_cache_ptr + source_offset; scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t* target_ptr = key_cache_ptr + target_offset; scalar_t* target_ptr = key_cache_ptr + target_offset;
@@ -81,28 +82,23 @@ void reshape_and_cache_cpu_impl(
} }
}; // namespace }; // namespace
void copy_blocks(std::vector<torch::Tensor> &key_caches, // Note: the key_caches and value_caches vectors are constant but
std::vector<torch::Tensor> &value_caches, // not the Tensors they contain. The vectors need to be const refs
const std::map<int64_t, std::vector<int64_t>> &block_mapping) { // in order to satisfy pytorch's C++ operator registration code.
int num_layers = key_caches.size(); void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
return; return;
} }
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel(); const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs, copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers); element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
}); });
@@ -111,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string &kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, double kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
@@ -136,6 +132,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
} }
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const std::map<int64_t, int64_t> &block_mapping) { const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
} }

View File

@@ -3,7 +3,7 @@
#define CPU_TYPES_HPP #define CPU_TYPES_HPP
#include <immintrin.h> #include <immintrin.h>
#include <torch/extension.h> #include <torch/all.h>
namespace vec_op { namespace vec_op {

View File

@@ -87,8 +87,8 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
} }
} // namespace } // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor &weight, float epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@@ -102,7 +102,7 @@ void rms_norm(torch::Tensor &out, torch::Tensor &input,
} }
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor &weight, float epsilon) { torch::Tensor& weight, double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;

View File

@@ -4,25 +4,74 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_impl( void rotary_embedding_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
constexpr int ELEM_SIZE = sizeof(scalar_t);
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); bool flag = (embed_dim % VEC_ELEM_NUM == 0);
const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
scalar_t* qk) {
int j = 0;
for (; j < loop_upper; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(qk + out_x);
const scalar_vec_t q_y(qk + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(qk + out_x);
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(qk + out_y);
}
if (!flag) {
for (; j < embed_dim; ++j) {
const int x_index = j;
const int y_index = embed_dim + j;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const float fp32_cos = cache_ptr[x_index];
const float fp32_sin = cache_ptr[y_index];
const float fp32_q_x = qk[out_x];
const float fp32_q_y = qk[out_y];
qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
}
}
};
#pragma omp parallel for #pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
@@ -33,78 +82,29 @@ void rotary_embedding_impl(
const int head_idx = i; const int head_idx = i;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_size; token_idx * query_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { compute_loop(token_head, cache_ptr, query);
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(query + out_x);
const scalar_vec_t q_y(query + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(query + out_x);
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(query + out_y);
}
} }
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i; const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { compute_loop(token_head, cache_ptr, key);
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t k_x(key + out_x);
const scalar_vec_t k_y(key + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_k_x(k_x);
vec_op::FP32Vec8 fp32_k_y(k_y);
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
scalar_vec_t(out1).save(key + out_x);
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
scalar_vec_t(out2).save(key + out_y);
}
} }
} }
} }
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_gptj_impl( void rotary_embedding_gptj_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
@@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl(
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor &key, int head_size, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);

View File

@@ -1,73 +0,0 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
}

106
csrc/cpu/torch_bindings.cpp Normal file
View File

@@ -0,0 +1,106 @@
#include "cache.h"
#include "ops.h"
#include "registration.h"
#include <torch/library.h>
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);
// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCPU, &gelu_new);
// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm", torch::kCPU, &rms_norm);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float kv_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@@ -17,9 +17,14 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else #else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor(var, lane_mask, width)
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
@@ -28,6 +33,13 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif #endif
#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
@@ -35,4 +47,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif #endif

View File

@@ -1,10 +1,5 @@
#pragma once #pragma once
#include <torch/extension.h> int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int get_device_attribute( int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute(
int device_id);

View File

@@ -2,26 +2,20 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
int attribute,
int device_id)
{
int device, value; int device, value;
if (device_id < 0) { if (device_id < 0) {
cudaGetDevice(&device); cudaGetDevice(&device);
} } else {
else {
device = device_id; device = device_id;
} }
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device); cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
device);
return value; return value;
} }
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
int get_max_shared_memory_per_block_device_attribute( int64_t attribute;
int device_id)
{
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

View File

@@ -1,17 +1,17 @@
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include <torch/extension.h> #include <torch/all.h>
#include "custom_all_reduce.cuh" #include "custom_all_reduce.cuh"
// fake pointer type // fake pointer type, must match fptr_t type in ops.h
using fptr_t = uint64_t; using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t)); static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = offsets.size(); int world_size = offsets.size();
if (world_size > 8) if (world_size > 8)
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor &t) {
t.numel() * t.element_size()); t.numel() * t.element_size());
} }
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink) { bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size(); auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16 // custom allreduce requires input byte size to be multiples of 16
@@ -80,8 +80,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
} }
case at::ScalarType::Half: { case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()), fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), out.numel());
out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
@@ -126,7 +125,7 @@ void dispose(fptr_t _fa) {
delete fa; delete fa;
} }
int meta_size() { return sizeof(vllm::Signal); } int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
@@ -135,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor &t,
fa->register_buffer(handles, offsets, t.data_ptr()); fa->register_buffer(handles, offsets, t.data_ptr());
} }
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) { fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta(); auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
} }
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,

View File

@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]) while (!self_sg->start[blockIdx.x][threadIdx.x]);
;
} }
__syncthreads(); __syncthreads();
} }
@@ -162,8 +161,7 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]) while (!self_sg->end[blockIdx.x][threadIdx.x]);
;
} }
if constexpr (!final_sync) __syncthreads(); if constexpr (!final_sync) __syncthreads();
} }
@@ -192,8 +190,7 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P *)result)[idx] = ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
} }
end_sync<ngpus, true>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }

View File

@@ -4,7 +4,7 @@
*/ */
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
@@ -12,8 +12,7 @@
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
@@ -22,8 +21,8 @@
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
@@ -33,5 +32,4 @@
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

View File

@@ -1,4 +1,4 @@
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
@@ -23,9 +23,7 @@ __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
@@ -41,11 +39,11 @@ __global__ void rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types, /* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion to be implemented for now because the relevant type conversion
@@ -57,7 +55,9 @@ __global__ void rms_norm_kernel(
If true, the struct should be fully defined as shown in the examples below. If true, the struct should be fully defined as shown in the examples below.
*/ */
template <typename torch_type> template <typename torch_type>
struct _typeConvert { static constexpr bool exists = false; }; struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion // CUDA < 12.0 runs into issues with packed type conversion
@@ -68,9 +68,15 @@ struct _typeConvert<c10::Half> {
using packed_hip_type = __half2; using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline float2 convert(packed_hip_type x) {
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); } return __half22float2(x);
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -82,13 +88,22 @@ struct _typeConvert<c10::BFloat16> {
using hip_type = __nv_bfloat16; using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162; using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); } __device__ static inline float convert(hip_type x) {
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } return __bfloat162float(x);
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } __device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
}; };
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel. for appropriate specializations of fused_add_rms_norm_kernel.
@@ -117,8 +132,7 @@ struct alignas(16) _f16Vec {
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] += other.data[i];
data[i] += other.data[i];
} }
return *this; return *this;
} }
@@ -134,8 +148,7 @@ struct alignas(16) _f16Vec {
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] *= other.data[i];
data[i] *= other.data[i];
} }
return *this; return *this;
} }
@@ -185,14 +198,12 @@ struct alignas(16) _f16Vec {
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template <typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic // Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are /* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */ in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input); auto* __restrict__ input_v =
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
@@ -218,7 +232,8 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
@@ -233,19 +248,16 @@ __global__ std::enable_if_t<
} }
} }
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template <typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
@@ -260,7 +272,8 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
@@ -268,17 +281,17 @@ __global__ std::enable_if_t<
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
} // namespace vllm } // namespace vllm
void rms_norm( void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
float epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@@ -286,40 +299,27 @@ void rms_norm(
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
input.scalar_type(),
"rms_norm_kernel",
[&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
}); });
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
"fused_add_rms_norm_kernel", \ vllm::fused_add_rms_norm_kernel<scalar_t, width> \
[&] { \ <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \ residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \ weight.data_ptr<scalar_t>(), epsilon, \
epsilon, \ num_tokens, hidden_size); \
num_tokens, \
hidden_size); \
}); });
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
float epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ bool ptrs_are_aligned =
&& wt_ptr % 16 == 0; inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) { if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);
} else { } else {

View File

@@ -1,7 +0,0 @@
#include "moe_ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}

View File

@@ -1,9 +1,7 @@
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
void topk_softmax( void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& gating_output); torch::Tensor& gating_output);

View File

@@ -16,18 +16,25 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include <cub/cub.cuh> #ifndef USE_ROCM
#include <cub/util_type.cuh> #include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace vllm { namespace vllm {
namespace moe { namespace moe {
static constexpr int WARP_SIZE = 32;
/// Aligned array type /// Aligned array type
template < template <
typename T, typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll #pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{ {
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
} }
// From this point, thread max in all the threads have the max within the row. // From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll #pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{ {
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
} }
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll #pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{ {
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way // We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert)) if (other_max > max_val || (other_max == max_val && other_expert < expert))
@@ -383,7 +390,7 @@ struct TopkConstants
{ {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
{ {
static constexpr std::size_t MAX_BYTES_PER_LDG = 16; static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>; using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT; static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;

View File

@@ -0,0 +1,12 @@
#include "registration.h"
#include "moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@@ -1,4 +1,4 @@
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
@@ -12,11 +12,12 @@
namespace vllm { namespace vllm {
namespace { namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
// don't worry about overflow because num_experts is relatively small // don't worry about overflow because num_experts is relatively small
return row * total_col + col; return row * total_col + col;
} }
} } // namespace
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
@@ -24,15 +25,17 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
int32_t* expert_ids, int32_t* expert_ids,
int32_t* total_tokens_post_pad, int32_t* total_tokens_post_pad,
int32_t num_experts, int32_t num_experts,
int32_t block_size, int32_t block_size, size_t numel) {
size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread; const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[]; extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) int32_t* tokens_cnts =
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum =
shared_mem + (num_experts + 1) *
num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
/** /**
* In the first step we compute token_cnts[thread_index + 1][expert_index], * In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned * which counts how many tokens in the token shard of thread_index are
* to expert expert_index. * assigned to expert expert_index.
*/ */
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
@@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
// For each expert we accumulate the token counts from the different threads. // For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) { for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
} }
__syncthreads(); __syncthreads();
@@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
cumsum[0] = 0; cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) { for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
} }
*total_tokens_post_pad = cumsum[num_experts]; *total_tokens_post_pad = cumsum[num_experts];
} }
@@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
__syncthreads(); __syncthreads();
/** /**
* For each expert, each thread processes the tokens of the corresponding blocks * For each expert, each thread processes the tokens of the corresponding
* and stores the corresponding expert_id for each block. * blocks and stores the corresponding expert_id for each block.
*/ */
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x; expert_ids[i / block_size] = threadIdx.x;
} }
/** /**
* Each thread processes a token shard, calculating the index of each token after * Each thread processes a token shard, calculating the index of each token
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and * after sorting by expert number. Given the example topk_ids =
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* where * represents a padding value(preset in python). * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/ */
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i]; int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the /** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] * expert with expert_id needs to process, and
* stores the indices of the tokens processed by the expert with expert_id within * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* the current thread's token shard. * processed by the expert with expert_id within the current thread's token
* shard.
*/ */
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i; sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
} }
} }
} } // namespace vllm
void moe_align_block_size( void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor topk_ids, int64_t block_size, torch::Tensor sorted_token_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); // tensors
const int32_t shared_mem =
((num_experts + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// set dynamic shared mem // set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>; auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK( AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); (void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>( kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
num_experts,
block_size,
topk_ids.numel()); topk_ids.numel());
}); });
} }

View File

@@ -1,181 +1,146 @@
#pragma once #pragma once
#include <torch/extension.h> #include <torch/library.h>
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& query, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
torch::Tensor& value_cache, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int num_kv_heads, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
float scale, const int64_t blocksparse_local_blocks,
torch::Tensor& block_tables, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
torch::Tensor& context_lens, const int64_t blocksparse_head_sliding_step);
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& exp_sums, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& tmp_out, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
torch::Tensor& query, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
torch::Tensor& value_cache, const int64_t blocksparse_local_blocks,
int num_kv_heads, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
float scale, const int64_t blocksparse_head_sliding_step);
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void rms_norm( void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& out, double epsilon);
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& input, torch::Tensor& weight, double epsilon);
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
void rotary_embedding( void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& positions, torch::Tensor& key, int64_t head_size,
torch::Tensor& query, torch::Tensor& cos_sin_cache, bool is_neox);
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
void batched_rotary_embedding( void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& positions, torch::Tensor& key, int64_t head_size,
torch::Tensor& query, torch::Tensor& cos_sin_cache, bool is_neox,
torch::Tensor& key, int64_t rot_dim,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets); torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul( void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_new( void gelu_new(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast( void gelu_fast(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor aqlm_gemm( torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes, const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias const std::optional<torch::Tensor>& bias);
);
torch::Tensor aqlm_dequant( torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes const torch::Tensor& codebook_partition_sizes);
);
torch::Tensor awq_gemm( torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _in_feats, torch::Tensor _scaling_factors, torch::Tensor _zeros,
torch::Tensor _kernel, int64_t split_k_iters);
torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor _zeros, int64_t split_k_iters,
int split_k_iters); int64_t thx, int64_t thy);
torch::Tensor awq_dequantize( torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor _kernel, torch::Tensor& b_scales, torch::Tensor& workspace,
torch::Tensor _scaling_factors, int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy);
torch::Tensor marlin_gemm( torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& a, torch::Tensor& b_meta,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_scales,
torch::Tensor& workspace, torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_m, int64_t size_n,
int64_t size_n,
int64_t size_k); int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif #endif
void squeezellm_gemm( void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor vec, torch::Tensor const& scale);
torch::Tensor mat,
torch::Tensor mul, void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table); torch::Tensor lookup_table);
torch::Tensor gptq_gemm( torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
torch::Tensor b_g_idx, bool use_exllama, int64_t bit);
bool use_exllama,
int bit);
void gptq_shuffle( void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
torch::Tensor q_weight,
torch::Tensor q_perm,
int bit);
void scaled_fp8_quant( void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
void moe_align_block_size( void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor topk_ids, torch::Tensor& scale);
int num_experts,
int block_size, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = uint64_t; using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink); bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink); bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out); torch::Tensor& out);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int meta_size(); int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets); const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets); const std::vector<std::vector<int64_t>>& offsets);
#endif #endif

View File

@@ -1,4 +1,4 @@
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
@@ -9,12 +9,8 @@ namespace vllm {
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding( inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index; int x_index, y_index;
scalar_t cos, sin; scalar_t cos, sin;
if (IS_NEOX) { if (IS_NEOX) {
@@ -39,17 +35,15 @@ inline __device__ void apply_token_rotary_embedding(
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding( inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] // head_size] or [num_tokens, num_heads,
const scalar_t* cache_ptr, // head_size]
const int head_size, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int num_heads, // head_size] or [num_tokens, num_kv_heads,
const int num_kv_heads, // head_size]
const int rot_dim, const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t query_stride, const int64_t key_stride) {
const int64_t key_stride)
{
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
const int nk = num_kv_heads * embed_dim; const int nk = num_kv_heads * embed_dim;
@@ -68,60 +62,72 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
} }
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel( __global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int rot_dim, // head_size]
const int64_t query_stride, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t key_stride, // head_size] or [num_tokens, num_kv_heads,
const int num_heads, // head_size]
const int num_kv_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int head_size) { // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel( __global__ void batched_rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] // head_size]
const int rot_dim, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t query_stride, // head_size] or [num_tokens, num_kv_heads,
const int64_t key_stride, // head_size]
const int num_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int num_kv_heads, // 2]
const int head_size) { const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; const scalar_t* cache_ptr =
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
} // namespace vllm } // namespace vllm
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1); int64_t num_tokens = query.numel() / query.size(-1);
@@ -132,36 +138,21 @@ void rotary_embedding(
int64_t key_stride = key.stride(-2); int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) { if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
key.data_ptr<scalar_t>(), query_stride, key_stride, num_heads, num_kv_heads, head_size);
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else { } else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>( vllm::rotary_embedding_kernel<scalar_t, false>
positions.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size); head_size);
} }
}); });
@@ -173,12 +164,13 @@ and process in batched manner.
*/ */
void batched_rotary_embedding( void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, bool is_neox, int64_t rot_dim,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens] torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) { ) {
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
@@ -188,39 +180,24 @@ void batched_rotary_embedding(
int64_t key_stride = key.stride(-2); int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) { if (is_neox) {
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( vllm::batched_rotary_embedding_kernel<scalar_t, true>
positions.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(), cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
cos_sin_cache_offsets.data_ptr<int64_t>(), key_stride, num_heads, num_kv_heads, head_size);
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else { } else {
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>( vllm::batched_rotary_embedding_kernel<scalar_t, false>
positions.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(), cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
cos_sin_cache_offsets.data_ptr<int64_t>(), key_stride, num_heads, num_kv_heads, head_size);
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} }
}); });
} }

View File

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

View File

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)

View File

@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \ f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4096) \
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \ f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 7168) \
@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32256) \
@@ -74,6 +77,77 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py // and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig // Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
@@ -81,4 +155,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on // clang-format on

View File

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)

View File

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)

View File

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)

View File

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)

View File

@@ -1,8 +1,14 @@
#pragma once #pragma once
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h> #include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline> #include <cuda/pipeline>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <iostream>
#include <stdio.h> #include <stdio.h>
@@ -11,6 +17,24 @@
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
// Does not handle the case of long datatypes
char *d = reinterpret_cast<char *>(dst);
const char *s = reinterpret_cast<const char *>(src);
size_t i = 0;
#pragma unroll
for (i = 0; i < len; ++i) {
d[i] = s[i];
}
return dst;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4) // nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size, template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T, size_t W_copy_size, int tx, int ty, int tz, typename in_T,
@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
} }
} }
#else
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
size_t j = blockIdx.x;
constexpr size_t tile_size = tx * ty * vec_size;
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
__shared__ float y_warpwise[ty];
float y = 0;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
x_vec.load(X + (batch_idx * feat_in) +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
}
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
}
__syncthreads();
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
y += sum;
}
}
if (threadIdx.x == 0) {
y_warpwise[threadIdx.y] = y;
}
__syncthreads();
float y_write = 0.f;
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y_write += y_warpwise[i];
}
// write Y;
if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t y_idx = batch_idx * full_y_size + y_offset + j;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
}
}
#endif
// nthrs = (2, 16, 4) // nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz, template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T> typename in_T, typename out_T, typename W_T>
@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
float sum = 0.f; float sum = 0.f;
#pragma unroll #pragma unroll
for (size_t i = 0; i < vec_size; ++i) { for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
sum += float(w_vec[i]) * float(x_vec[i]) * scale; sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
} }
cg::thread_block_tile g = cg::tiled_partition<tx>(block); cg::thread_block_tile g = cg::tiled_partition<tx>(block);
@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
#ifndef USE_ROCM
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum); threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
} }
} }
@@ -199,7 +308,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr int tz = 4; constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in < feat_out) { if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0); static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size; constexpr int tx = feat_in / vec_size;
@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
scale); scale);
} }
} else { } else {
#ifndef USE_ROCM
static_assert(feat_in % (vec_size * 32) == 0 || static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0); feat_in % (vec_size * 8) == 0);
@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
full_y_size, num_layers, layer_idx, full_y_size, num_layers, layer_idx,
scale); scale);
} }
#else
constexpr size_t rocm_warp_size = warpSize;
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
feat_in % (rocm_warp_size * vec_size_) == 0
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
constexpr size_t vec_size_shrink = vec_size_; \
constexpr int tx = tx_; \
constexpr int ty = ty_; \
dim3 nblks(feat_out, batch_size); \
dim3 nthrs(tx, ty); \
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
vec_size_shrink * sizeof(in_T), \
vec_size_shrink * sizeof(W_T), \
tx, ty, tz> \
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
full_y_size, num_layers, layer_idx, \
scale); \
}
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
CHECK_INPUT_TILEABLE_BY(16) ||
CHECK_INPUT_TILEABLE_BY( 8) ||
CHECK_INPUT_TILEABLE_BY( 4) ||
CHECK_INPUT_TILEABLE_BY( 2) ||
CHECK_INPUT_TILEABLE_BY( 1));
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
} }
} }
@@ -289,6 +443,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale); int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \ INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T) INST_BGMV(wide, narrow, in_T, out_T, W_T)

View File

@@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501 """.lstrip() # noqa: E501
for input_dtype in DTYPES: for input_dtype in DTYPES:

View File

@@ -1,8 +1,6 @@
#ifndef VEC_DTYPES_CUH_ #ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8 #ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h> #include <cuda_fp8.h>
#endif #endif
@@ -10,6 +8,9 @@
#include <type_traits> #include <type_traits>
#include "../type_convert.h"
#include "../../cuda_compat.h"
#define FLASHINFER_INLINE \ #define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__ inline __attribute__((always_inline)) __device__ __host__

View File

@@ -1,12 +1,11 @@
#include <cuda_bf16.h> #include <torch/all.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cstdint> #include <cstdint>
#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h" #include "bgmv/bgmv_config.h"
namespace {
//====== utils ====== //====== utils ======
@@ -79,17 +78,17 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _) FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE #undef CASE
#undef CASE_ONESIDE #undef CASE_ONESIDE
default: default:
return false; return false;
} }
return true; return true;
} }
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale) { torch::Tensor indicies, int64_t layer_idx, double scale) {
CHECK_INPUT(y); CHECK_INPUT(y);
CHECK_INPUT(x); CHECK_INPUT(x);
CHECK_INPUT(w); CHECK_INPUT(w);
@@ -321,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out, double scale, int64_t h_in, int64_t h_out,
int64_t y_offset) { int64_t y_offset) {
CHECK_INPUT(y); CHECK_INPUT(y);
CHECK_INPUT(x); CHECK_INPUT(x);
@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
} }
} // namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}

11
csrc/punica/punica_ops.h Normal file
View File

@@ -0,0 +1,11 @@
#pragma once
#include <torch/all.h>
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, double scale);
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
double scale, int64_t h_in, int64_t h_out,
int64_t y_offset);

View File

@@ -0,0 +1,18 @@
#include "registration.h"
#include "punica_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()");
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
m.def(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()");
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@@ -0,0 +1,82 @@
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
typedef __half nv_half;
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
return __hip_bfloat162{val, val};
}
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
return __hip_bfloat162{vall, valr};
}
template <typename T_src, typename T_dst>
__TYPE_CONVERT__HOST_DEVICE__
inline T_dst convert_type(T_src val) {
return static_cast<T_dst>(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__half, float>(__half val) {
return __half2float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half convert_type<float, __half>(float val) {
return __float2half(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
return __bfloat162float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
return __float2bfloat16(val);
}
template <typename T>
__TYPE_CONVERT__HOST_DEVICE__
inline T vllm_add(T a, T b) {
return a + b;
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half vllm_add<__half>(__half a, __half b) {
return __hadd(a, b);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
return __hadd(a, b);
}
#undef __TYPE_CONVERT__HOST_DEVICE__
#endif // USE_ROCM
#endif // CSRC__PUNICA__TYPE_CONVERT_H__

View File

@@ -1,129 +0,0 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def(
"batched_rotary_embedding",
&batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
// Quantization ops
#ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"convert_fp8",
&convert_fp8,
"Convert the key and value cache to fp8 data type");
// Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
custom_ar.def("dispose", &dispose, "dispose");
custom_ar.def("meta_size", &meta_size, "meta_size");
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
}

View File

@@ -18,37 +18,33 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <iostream> #include <iostream>
#include <cstdlib> #include <cstdlib>
namespace vllm { namespace vllm {
namespace aqlm { namespace aqlm {
__global__ void Code1x16MatVec( __global__ void Code1x16MatVec(
const int4* __restrict__ A, const int4* __restrict__ A, const int4* __restrict__ B,
const int4* __restrict__ B, int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
int4* __restrict__ C,
const int4* __restrict__ codebook,
const int prob_m,
const int prob_k, const int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4. const int codebook_stride // as int4.
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{
codebook += codebook_stride; codebook += codebook_stride;
++codebook_size; ++codebook_size;
} }
@@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
// We pad shared memory to avoid bank conflicts during reads // We pad shared memory to avoid bank conflicts during reads
__syncthreads(); __syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
} }
__syncthreads(); __syncthreads();
b_gl_rd += 32 * 8; b_gl_rd += 32 * 8;
@@ -79,19 +74,16 @@ __global__ void Code1x16MatVec(
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
uint32_t dec[4]; uint32_t dec[4];
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't // We bypass the L1 cache to avoid massive amounts of memory streaming
// actually help us; this brings > 2x speedup. // that doesn't actually help us; this brings > 2x speedup.
asm volatile ( asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*) &codebook[enc[i]]) : "l"((void*)&codebook[enc[i]]));
);
half2* a = reinterpret_cast<half2*>(&dec); half2* a = reinterpret_cast<half2*>(&dec);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {}; half2 res2 = {};
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
res2 = __hfma2(a[j], b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y); res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++; b_sh_rd++;
} }
@@ -101,21 +93,18 @@ __global__ void Code1x16MatVec(
if (pred) { if (pred) {
#pragma unroll #pragma unroll
for (int i = 16; i > 0; i /= 2) for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0) if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
} }
} }
__global__ void Code2x8MatVec( __global__ void Code2x8MatVec(
const int4* __restrict__ A, const int4* __restrict__ A, const int4* __restrict__ B,
const int4* __restrict__ B, int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k, int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4. const int codebook_stride // as int4.
) { ) {
@@ -123,12 +112,11 @@ __global__ void Code2x8MatVec(
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{
codebook += codebook_stride; codebook += codebook_stride;
++codebook_size; ++codebook_size;
} }
@@ -149,8 +137,7 @@ __global__ void Code2x8MatVec(
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i]; int4 dec = codebook[i];
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
sh_code[8 * i + (j + lane) % 8] = dec;
} }
__syncthreads(); __syncthreads();
@@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
// We pad shared memory to avoid bank conflicts during reads // We pad shared memory to avoid bank conflicts during reads
__syncthreads(); __syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
} }
__syncthreads(); __syncthreads();
b_gl_rd += 32 * 8; b_gl_rd += 32 * 8;
@@ -172,8 +158,10 @@ __global__ void Code2x8MatVec(
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]); const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); half2* a0 =
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {}; half2 res2 = {};
#pragma unroll #pragma unroll
@@ -188,33 +176,28 @@ __global__ void Code2x8MatVec(
if (pred) { if (pred) {
#pragma unroll #pragma unroll
for (int i = 16; i > 0; i /= 2) for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0) if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
} }
} }
__global__ void Code1x16Dequant( __global__ void Code1x16Dequant(
const int4* __restrict__ A, const int4* __restrict__ A, int4* __restrict__ C,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
int prob_m, // codebook, at most 3 long, sums to m.
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
const int codebook_stride // as int4 const int codebook_stride // as int4
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{
codebook += codebook_stride; codebook += codebook_stride;
++codebook_size; ++codebook_size;
} }
@@ -235,13 +218,11 @@ __global__ void Code1x16Dequant(
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
int4 chunk; int4 chunk;
auto dec = reinterpret_cast<uint32_t*>(&chunk); auto dec = reinterpret_cast<uint32_t*>(&chunk);
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't // We bypass the L1 cache to avoid massive amounts of memory streaming
// actually help us; this brings > 2x speedup. // that doesn't actually help us; this brings > 2x speedup.
asm volatile ( asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*) &codebook[enc[i]]) : "l"((void*)&codebook[enc[i]]));
);
C[a_gl_rd * 8 + i] = chunk; C[a_gl_rd * 8 + i] = chunk;
} }
@@ -250,26 +231,23 @@ __global__ void Code1x16Dequant(
} }
} }
__global__ void Code2x8Dequant( __global__ void Code2x8Dequant(
const int4* __restrict__ A, const int4* __restrict__ A, int4* __restrict__ C,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4* __restrict__ codebook, const int4
int prob_m, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
int prob_k, // most 3 long, corresponds to cols.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4 const int codebook_stride // as int4
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{
codebook += codebook_stride; codebook += codebook_stride;
++codebook_size; ++codebook_size;
} }
@@ -291,8 +269,7 @@ __global__ void Code2x8Dequant(
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i]; int4 dec = codebook[i];
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
sh_code[8 * i + (j + lane) % 8] = dec;
} }
__syncthreads(); __syncthreads();
@@ -305,8 +282,10 @@ __global__ void Code2x8Dequant(
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
int4 chunk; int4 chunk;
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); half2* a0 =
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]); reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
@@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
} }
} }
inline int ceildiv(int a, int b) { inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
return (a + b - 1) / b;
}
const int THREAD_M = 16; const int THREAD_M = 16;
void code1x16_matvec_cuda( void code1x16_matvec_cuda(const void* __restrict__ A,
const void* __restrict__ A, const void* __restrict__ B, void* __restrict__ C,
const void* __restrict__ B, const void* __restrict__ codebook, int prob_m,
void* __restrict__ C, int prob_k, const int4 codebook_a_sizes,
const void* __restrict__ codebook, const int codebook_stride) {
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0; int waves = 0;
@@ -346,27 +318,15 @@ void code1x16_matvec_cuda(
int threads = 32 * thread_m; int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>( Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
(const int4*) A, (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
(const int4*) B, prob_k, codebook_a_sizes, codebook_stride);
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
void code2x8_matvec_cuda( void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
const void* __restrict__ A,
const void* __restrict__ B,
void* __restrict__ C, void* __restrict__ C,
const void* __restrict__ codebook, const void* __restrict__ codebook, int prob_m,
int prob_m, int prob_k, const int4 codebook_a_sizes,
int prob_k, const int codebook_stride) {
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0; int waves = 0;
@@ -379,29 +339,19 @@ void code2x8_matvec_cuda(
int blocks = ceildiv(prob_m, thread_m); int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m; int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9); int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaFuncSetAttribute( cudaFuncSetAttribute(Code2x8MatVec,
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code2x8MatVec<<<blocks, threads, shared, stream>>>( Code2x8MatVec<<<blocks, threads, shared, stream>>>(
(const int4*) A, (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
(const int4*) B, prob_k, codebook_a_sizes, codebook_stride);
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
void code1x16_dequant_cuda( void code1x16_dequant_cuda(
const void* __restrict__ A, const void* __restrict__ A, void* __restrict__ C,
void* __restrict__ C, const void* __restrict__ codebook, int prob_m, int prob_k,
const void* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
int prob_m, // codebook, at most 3 long.
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4. const int codebook_stride // as int4.
) { ) {
int sms; int sms;
@@ -417,24 +367,20 @@ void code1x16_dequant_cuda(
int threads = 32 * thread_m; int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16Dequant<<<blocks, threads, 0, stream>>>( Code1x16Dequant<<<blocks, threads, 0, stream>>>(
(const int4*) A, (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
(int4*) C, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
(const int4*) codebook, // most 3 long.
prob_m,
prob_k,
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
codebook_stride // as int4. codebook_stride // as int4.
); );
} }
// Dequantizes the code and codebook into weights. // Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda( void code2x8_dequant_cuda(
const void* __restrict__ A, const void* __restrict__ A, void* __restrict__ C,
void* __restrict__ C, const void* __restrict__ codebook, int prob_m, int prob_k,
const void* __restrict__ codebook, const int4
int prob_m, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
int prob_k, // most 3 long, corresponds to cols.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4 const int codebook_stride // as int4
) { ) {
int sms; int sms;
@@ -451,50 +397,33 @@ void code2x8_dequant_cuda(
int shared = 16 * (2 * 256 * 8 + 32 * 9); int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cudaFuncSetAttribute( cudaFuncSetAttribute(Code2x8Dequant,
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
);
Code2x8Dequant<<<blocks, threads, shared, stream>>>( Code2x8Dequant<<<blocks, threads, shared, stream>>>(
(const int4*) A, (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
(int4*) C, codebook_a_sizes, codebook_stride);
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
int codebook_stride(const torch::Tensor& codebooks) int codebook_stride(const torch::Tensor& codebooks) {
{
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
} }
void code1x16_matvec( void code1x16_matvec(
const torch::Tensor& A, const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
const torch::Tensor& B,
torch::Tensor& C,
const torch::Tensor& codebook, const torch::Tensor& codebook,
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. const int4 codebook_a_sizes // cumulative sizes of A spanning each
// codebook, at most 3 long.
) { ) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0); int prob_m = C.size(0);
int prob_k = B.size(0); int prob_k = B.size(0);
code1x16_matvec_cuda( code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
A.data_ptr(), codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
B.data_ptr(), codebook_stride(codebook));
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride(codebook)
);
} }
torch::Tensor code1x16_matmat( torch::Tensor code1x16_matmat(const torch::Tensor& input,
const torch::Tensor& input,
const torch::Tensor& codes, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
@@ -503,22 +432,15 @@ torch::Tensor code1x16_matmat(
auto input_sizes = input.sizes(); auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2); auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)}); auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features}, auto flat_output = torch::empty(
torch::TensorOptions() {flat_input.size(0), out_features},
.dtype(input.dtype()) torch::TensorOptions().dtype(input.dtype()).device(input.device()));
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) { for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i}); auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i}); auto output_vec = flat_output.index({i});
code1x16_matvec( code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codes.squeeze(2), codebook_a_sizes);
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
} }
flat_output *= scales.flatten().unsqueeze(0); flat_output *= scales.flatten().unsqueeze(0);
@@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
return output; return output;
} }
void code2x8_matvec( void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& A, torch::Tensor& C, const torch::Tensor& codebook,
const torch::Tensor& B, const int4 codebook_a_sizes) {
torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0); int prob_m = C.size(0);
int prob_k = B.size(0); int prob_k = B.size(0);
code2x8_matvec_cuda( code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
A.data_ptr(), codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
B.data_ptr(), 2 * codebook_stride(codebook));
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
2 * codebook_stride(codebook)
);
} }
torch::Tensor code2x8_matmat( torch::Tensor code2x8_matmat(const torch::Tensor& input,
const torch::Tensor& input,
const torch::Tensor& codes, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const int4 codebook_a_sizes, const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias const std::optional<torch::Tensor>& bias) {
) {
auto input_sizes = input.sizes(); auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2); auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)}); auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features}, auto flat_output = torch::empty(
torch::TensorOptions() {flat_input.size(0), out_features},
.dtype(input.dtype()) torch::TensorOptions().dtype(input.dtype()).device(input.device()));
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) { for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i}); auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i}); auto output_vec = flat_output.index({i});
code2x8_matvec( code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codes.squeeze(2), codebook_a_sizes);
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
} }
flat_output *= scales.flatten().unsqueeze(0); flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) { if (bias.has_value()) {
@@ -596,21 +498,18 @@ torch::Tensor code2x8_matmat(
} }
// Accumulate the partition sizes. // Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
{
int4 cumulative_sizes; int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x; auto cumulative_size = &cumulative_sizes.x;
int i = 0; int i = 0;
int last = 0; int last = 0;
assert(codebook_partition_sizes.size(0) <= 4); assert(codebook_partition_sizes.size(0) <= 4);
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
{
*cumulative_size = codebook_partition_sizes[i].item<int>() + last; *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
last = *cumulative_size; last = *cumulative_size;
} }
// fill in the rest with unreachable. // fill in the rest with unreachable.
for (; i < 4; ++i, ++cumulative_size) for (; i < 4; ++i, ++cumulative_size) {
{
*cumulative_size = last * 10; *cumulative_size = last * 10;
} }
return cumulative_sizes; return cumulative_sizes;
@@ -619,41 +518,36 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
} // namespace aqlm } // namespace aqlm
} // namespace vllm } // namespace vllm
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
torch::Tensor aqlm_gemm(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes, const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias const std::optional<torch::Tensor>& bias) {
) int4 cumulative_sizes =
{ vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16)) if (nbooks == 1 && entries == (1 << 16)) {
{ return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); cumulative_sizes, bias);
} }
if (nbooks == 2 && entries == (1 << 8)) if (nbooks == 2 && entries == (1 << 8)) {
{ return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); cumulative_sizes, bias);
} }
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {}; return {};
} }
torch::Tensor aqlm_dequant( torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes const torch::Tensor& codebook_partition_sizes) {
) int4 cumulative_sizes =
{ vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
@@ -670,43 +564,35 @@ torch::Tensor aqlm_dequant(
auto weights = torch::empty({out_features, in_features}, auto weights = torch::empty({out_features, in_features},
torch::TensorOptions() torch::TensorOptions()
.dtype(codebooks.dtype()) .dtype(codebooks.dtype())
.device(codebooks.device()) .device(codebooks.device()));
);
if (nbooks == 1 && entries == (1 << 16)) if (nbooks == 1 && entries == (1 << 16)) {
{ vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
vllm::aqlm::code1x16_dequant_cuda( codebooks.data_ptr(), out_features,
codes.data_ptr(), in_features, cumulative_sizes,
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks)); vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// weights *= scales.index({"...", 0, 0}); // and not consistent with gemv implementation.) weights *=
// scales.index({"...", 0, 0});
return weights; return weights;
} }
if (nbooks == 2 && entries == (1 << 8)) if (nbooks == 2 && entries == (1 << 8)) {
{ vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
vllm::aqlm::code2x8_dequant_cuda( codebooks.data_ptr(), out_features,
codes.data_ptr(), in_features, cumulative_sizes,
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks)); vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// weights *= scales.index({"...", 0, 0}); // and not consistent with gemv implementation) weights *=
// scales.index({"...", 0, 0});
return weights; return weights;
} }
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {}; return {};
} }

View File

@@ -1,11 +1,11 @@
/* /*
Adapted from https://github.com/mit-han-lab/llm-awq Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq, @article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, title={AWQ: Activation-aware Weight Quantization for LLM Compression and
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
journal={arXiv}, Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
year={2023}
} }
*/ */
@@ -14,8 +14,7 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace vllm { namespace vllm {
namespace awq { namespace awq {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false); assert(false);
#else #else
@@ -30,33 +29,40 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
static constexpr uint32_t TOP_MASK = 0x00f000f0; static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing // Note that the entire sequence only requires 1 shift instruction. This is
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. // thanks to the register packing format and the fact that we force our
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and // integers to be unsigned, and account for this in the fp16 subtractions. In
// elt_67 to fp16 without having to shift them to the bottom bits before hand. // addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// immediately before required. // dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8; const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0]) : "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1]) : "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2]) : "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3]) : "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the // I use inline PTX below because I am not sure if the compiler will emit
// half2 ctor. In this case, I chose performance reliability over code readability. // float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer. // This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
@@ -71,13 +77,21 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// Finally, we construct the output numbers. // Finally, we construct the output numbers.
// Convert elt_01 // Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23 // Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45 // Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[2])
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67 // Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[3])
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result; return result;
#endif #endif

View File

@@ -1,15 +1,13 @@
/* /*
Adapted from https://github.com/mit-han-lab/llm-awq Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq, @article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, title={AWQ: Activation-aware Weight Quantization for LLM Compression and
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
journal={arXiv}, Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
year={2023}
} }
*/ */
#include <torch/all.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh" #include "dequantize.cuh"
@@ -20,26 +18,20 @@ namespace vllm {
namespace awq { namespace awq {
// Pack two half values. // Pack two half values.
static inline __device__ __host__ unsigned static inline __device__ __host__ unsigned __pack_half2(const half x,
__pack_half2(const half x, const half y) { const half y) {
unsigned v0 = *((unsigned short*)&x); unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y); unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0; return (v1 << 16) | v0;
} }
template <int N> template <int N>
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( __global__ void __launch_bounds__(64)
int G, gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
int split_k_iters, half* __restrict__ A, int* __restrict__ B,
half* __restrict__ A,
int* __restrict__ B,
half* __restrict__ scaling_factors, half* __restrict__ scaling_factors,
int* __restrict__ zeros, int* __restrict__ zeros, int M, int IC,
int M, int OC, half* __restrict__ C) {
int IC,
int OC,
half* __restrict__ C)
{
// Only support matrix n = 64 or 128 // Only support matrix n = 64 or 128
assert(N == 64 || N == 128); assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
@@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
static constexpr int row_stride = 2 * 32 * 8 / N; static constexpr int row_stride = 2 * 32 * 8 / N;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id bool ld_A_flag =
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M; // bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A half* A_ptr =
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC A +
+ (((int)threadIdx.x) % (32 / 8)) * 8; (((int)blockIdx_y) / j_factors1 * 16 +
(((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
IC +
(((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
+ ((int)threadIdx.y) * (OC / 8) * (256 / N) (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8) (((int)blockIdx_y) % j_factors1) * (N / 8) +
+ (((int)blockIdx_y) % j_factors1) * (N / 8) (((int)threadIdx.x) % (N / 8)) * 1;
+ (((int)threadIdx.x) % (N / 8)) * 1;
// Why * 1 in the above line? // Why * 1 in the above line?
half* A_shared_ptr = A_shared half* A_shared_ptr = A_shared +
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8) ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8) (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
+ (((int)threadIdx.x) % (32 / 8) ) * 8; (((int)threadIdx.x) % (32 / 8)) * 8;
half* B_shared_ptr = B_shared half* B_shared_ptr = B_shared +
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8) ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
+ (((int)threadIdx.x) / (N / 8)) * (N + 8) (((int)threadIdx.x) / (N / 8)) * (N + 8) +
+ (((int)threadIdx.x) % (N / 8)) * 8; (((int)threadIdx.x) % (N / 8)) * 8;
int* zeros_ptr = zeros int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
+ (((int)blockIdx_y) % j_factors1) * (N / 8) ((int)threadIdx.x) % (N / 8);
+ ((int)threadIdx.x) % (N / 8);
half* scaling_factors_ptr = scaling_factors half* scaling_factors_ptr = scaling_factors +
+ (((int)blockIdx_y) % j_factors1) * N (((int)blockIdx_y) % j_factors1) * N +
+ (((int)threadIdx.x) % (N / 8)) * 8; (((int)threadIdx.x) % (N / 8)) * 8;
half* C_ptr = C half* C_ptr =
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim C +
+ (((int)blockIdx_y) % j_factors1) * N static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ ((int)threadIdx.y) * (N / 2) + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
+ (((int)threadIdx.x) % 4) * 2; (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros // preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
@@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads(); __syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag) if (ld_A_flag) {
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
} } else {
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
} }
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); uint4 B_loaded_scale =
*(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/* /*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
} }
*/ */
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16 // B: 32 x 136 (128+8) float16
// each warp: 32 x 4 // each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); // zero -> WB UINT4
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N) // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded =
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale // - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); // q * scale - zero * scale.
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); : "=r"(B_loaded_fp16.x)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); : "=r"(B_loaded_fp16.x)
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/* /*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
} }
*/ */
// write back // write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
B_loaded_fp16;
} }
__syncthreads(); __syncthreads();
@@ -173,34 +194,43 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
{ {
unsigned int addr; unsigned int addr;
__asm__ __volatile__( __asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }\n"
: "=r"(addr) : "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
); (((((int)threadIdx.x) & 15) * 40) +
((((int)threadIdx.x) >> 4) * 8)))));
__asm__ __volatile__( __asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16" "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n" "{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
: "r"(addr) "=r"(((unsigned*)(A_shared_warp + 0))[1]),
); "=r"(((unsigned*)(A_shared_warp + 0))[2]),
"=r"(((unsigned*)(A_shared_warp + 0))[3])
: "r"(addr));
} }
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
{ {
unsigned int addr; unsigned int addr;
__asm__ __volatile__( __asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }\n"
: "=r"(addr) : "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
); (((int)threadIdx.y) * (N / 2))) +
(ax1_0 * 16))])) +
(((((int)threadIdx.x) & 15) * (N + 8)) +
((((int)threadIdx.x) >> 4) * 8)))));
__asm__ __volatile__( __asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n" "{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
: "r"(addr) "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
); "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr));
} }
} }
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
@@ -209,48 +239,110 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
#else #else
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) "%13};\n"
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) "%13};\n"
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
#endif #endif
@@ -261,24 +353,20 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
// TODO: Shang: Hoist loop invariance. // TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) { for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
if (row_offset < M) ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
{ if (row_offset < M) {
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
} }
} }
} }
#endif #endif
} }
__global__ void __launch_bounds__(64) dequantize_weights( __global__ void __launch_bounds__(64)
int* __restrict__ B, dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
half* __restrict__ scaling_factors, int* __restrict__ zeros, half* __restrict__ C, int G) {
int* __restrict__ zeros,
half* __restrict__ C,
int G
)
{
int j_factors1 = 4; int j_factors1 = 4;
int row_stride2 = 4; int row_stride2 = 4;
int split_k_iters = 1; int split_k_iters = 1;
@@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
uint32_t B_loaded = *(uint32_t*)B_ptr2; uint32_t B_loaded = *(uint32_t*)B_ptr2;
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); : "=r"(B_loaded_fp16.x)
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); : "=r"(B_loaded_fp16.x)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); : "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4*)B_shared_ptr2 = B_loaded_fp16; *(uint4*)B_shared_ptr2 = B_loaded_fp16;
@@ -329,14 +433,10 @@ __global__ void __launch_bounds__(64) dequantize_weights(
} // namespace awq } // namespace awq
} // namespace vllm } // namespace vllm
torch::Tensor awq_dequantize( torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor _zeros, int64_t split_k_iters,
int split_k_iters, int64_t thx, int64_t thy) {
int thx,
int thy)
{
int in_c = _kernel.size(0); int in_c = _kernel.size(0);
int qout_c = _kernel.size(1); int qout_c = _kernel.size(1);
int out_c = qout_c * 8; int out_c = qout_c * 8;
@@ -362,12 +462,15 @@ torch::Tensor awq_dequantize(
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); auto options = torch::TensorOptions()
.dtype(_scaling_factors.dtype())
.device(_scaling_factors.device());
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>()); auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); auto scaling_factors =
reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
dim3 num_blocks(x_blocks, y_blocks); dim3 num_blocks(x_blocks, y_blocks);
@@ -386,26 +489,26 @@ torch::Tensor awq_dequantize(
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now // assume that batch_size < 16 for now
torch::Tensor awq_gemm( torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _in_feats, torch::Tensor _scaling_factors, torch::Tensor _zeros,
torch::Tensor _kernel, int64_t split_k_iters) {
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0); int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1); int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); auto options = torch::TensorOptions()
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); .dtype(_in_feats.dtype())
.device(_in_feats.device());
at::Tensor _out_feats =
torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2); int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1); int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>()); auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>()); auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); auto scaling_factors =
reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0); int group_size = num_in_channels / _scaling_factors.size(0);
@@ -419,28 +522,28 @@ torch::Tensor awq_gemm(
throw std::invalid_argument("OC is not multiple of Group size"); throw std::invalid_argument("OC is not multiple of Group size");
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_out_channels % 128 == 0) if (num_out_channels % 128 == 0) {
{
int j_factors1 = num_out_channels / 128 / 1; int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32 // threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2] // threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2); dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, <<<num_blocks, threads_per_block, 0, stream>>>(
num_out_channels, out_feats); group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
} num_in_feats, num_in_channels, num_out_channels, out_feats);
else if (num_out_channels % 64 == 0) } else if (num_out_channels % 64 == 0) {
{
int j_factors1 = num_out_channels / 64 / 1; int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
split_k_iters);
// threadIdx.x: 32 // threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2] // threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2); dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, <<<num_blocks, threads_per_block, 0, stream>>>(
num_out_channels, out_feats); group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
num_in_feats, num_in_channels, num_out_channels, out_feats);
} }
return _out_feats.sum(0); return _out_feats.sum(0);
} }

View File

@@ -0,0 +1,115 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
static const float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
static const float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
// round
float dst = std::nearbyint(x);
// saturate
dst = std::clamp(dst, i8_min, i8_max);
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}
namespace vllm {
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
}
}
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
}
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / 127.0f;
}
__syncthreads();
float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
}
}
} // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size);
});
}
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
});
}

View File

@@ -0,0 +1,346 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock {
using namespace cute;
using namespace detail;
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->row_broadcast) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
class ThreadMap,
class Element,
class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage { };
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gCol,
RTensor&& tC_rCol,
CTensor&& tC_cCol,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gCol(cute::forward<GTensor>(tC_gCol)),
tC_rCol(cute::forward<RTensor>(tC_rCol)),
tC_cCol(cute::forward<CTensor>(tC_cCol)),
m(get<0>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gCol;
RTensor tC_rCol;
CTensor tC_cCol;
Params const* params_ptr;
int m;
// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rCol);
Tensor pred = make_tensor<bool>(shape(tC_gCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tC_cCol(i)) < m;
}
if (params_ptr->col_broadcast) {
// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
} else {
// In this case we are loading from a scalar and broadcasting
auto dst_v = filter(tC_rCol);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if (pred(i)) {
dst_v(i) = *(params_ptr->ptr_col);
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Array<Element, FragmentSize> frg_col;
frg_col.fill(tC_rCol(row_idx,iter_idx));
return frg_col;
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mCol = make_tensor(
make_gmem_ptr(params_ptr->ptr_col),
problem_shape,
params_ptr->dCol);
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
Tensor tC_gCol = group_modes<1,4>(
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
Tensor tC_rCol = make_tensor_like(tC_gCol);
// Generate the pred tensor
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tC_cCol = group_modes<1,4>(
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
return Callbacks<
decltype(tC_gCol), decltype(tC_rCol),
decltype(tC_cCol), ProblemShape>(
cute::move(tC_gCol),
cute::move(tC_rCol),
cute::move(tC_cCol),
problem_shape,
params_ptr
);
}
};
}

View File

@@ -0,0 +1,389 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast
template<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct SharedStorage {
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params),
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
Params params;
Element* smem_row;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return true;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
}
template <int EpiTiles, class GTensor, class STensor>
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
CUTLASS_DEVICE
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
: gRow(cute::forward<GTensor>(gRow)),
sRow(cute::forward<STensor>(sRow)),
params(params) {}
GTensor gRow; // (CTA_M,CTA_N)
STensor sRow; // (CTA_M,CTA_N,PIPE)
Params const& params;
CUTLASS_DEVICE void
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
if (params.ptr_row == nullptr) {
return;
}
if (issue_tma_load) {
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
// Issue the TMA bulk copy
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
// Filter so we don't issue redundant copies over stride-0 modes
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
}
}
};
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
cute::move(gRow), cute::move(sRow), params);
}
template <int EpiTiles, class RTensor, class STensor>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
: tCrRow(cute::forward<RTensor>(tCrRow)),
tCsRow(cute::forward<STensor>(tCsRow)),
params(params) {}
RTensor tCrRow; // (CPY,CPY_M,CPY_N)
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
Params const& params;
CUTLASS_DEVICE void
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
if (!params.row_broadcast) {
fill(tCrRow, *(params.ptr_row));
return;
}
if (epi_m == 0) { // Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tCrRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
cute::move(tCrRow), cute::move(tCsRow), params);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
: tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
CUTLASS_DEVICE void
begin() {
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_aligned(filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
cute::move(tCgCol), cute::move(tCrCol), params);
}
};
}

View File

@@ -0,0 +1,12 @@
#pragma once
#include "cutlass/cutlass.h"
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
}

View File

@@ -0,0 +1,329 @@
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
namespace {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel>;
// clang-format on
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
evt0_compute_args};
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
typename Gemm::EVTD::Arguments epilogue_args{
evt1_compute_args,
d_args,
};
typename Gemm::Op::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
problem_size, // problem size
1, // batch count
epilogue_args,
a_ptr,
b_ptr,
nullptr,
nullptr,
0,
0,
0,
0,
lda,
ldb,
ldc,
ldc};
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}
} // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
}
}
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
}
}
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
}
}

View File

@@ -0,0 +1,340 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <iostream>
#include <sstream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
NVIDIA GPUs with sm90a (Hopper) or later.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
namespace {
uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <typename ElementAB_, typename ElementD_, typename TileShape,
typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleBDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, float>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute1>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>;
// clang-format off
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementAB, cutlass::layout::RowMajor, 16,
ElementAB, cutlass::layout::ColumnMajor, 16,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
// clang-format on
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args};
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
args.epilogue.thread = {a_args, {b_args}};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}
template <typename InType, typename OutType, int32_t M>
struct sm90_fp8_config {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
};
template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 128> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
};
template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 64> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
};
} // namespace
template <typename InType, typename OutType>
void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using Cutlass3xGemmDefault =
typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
// m in [1, 64]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
out, a, b, a_scales, b_scales);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
out, a, b, a_scales, b_scales);
} else {
// m in (128, inf)
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<
cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<
cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
}
#endif

View File

@@ -0,0 +1,75 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
int32_t major_capability;
int32_t minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
if (version_num >= 90) {
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
#else
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
} else if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
}
}

Some files were not shown because too many files have changed in this diff Show More