Compare commits

...

152 Commits

Author SHA1 Message Date
Isotr0py
5506435419 [Misc] Clean up Gemma4 implementation (#38872)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
2026-04-03 05:47:02 +00:00
Yifan Qiao
311c981647 [MRV2][KVConnector] Fix missing build_connector_worker_meta (#38698)
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
2026-04-03 08:42:52 +03:00
Li, Jiang
21d7ecc5b0 [CI/Build] Add audio deps in Dockerfile.cpu (#38876)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
2026-04-03 05:05:14 +00:00
Aaron Hao
4729b90838 [Bug] Add e_score_correction_bias to SKIP_TENSORS (#38746)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
2026-04-02 21:15:05 -07:00
shunting314
8b141ed8c3 full cudagraph for flex-attn (#36298)
Signed-off-by: shunting314 <shunting@meta.com>
2026-04-02 21:15:01 -07:00
Varun Sundar Rabindranath
2ad7c0335f [Model] Add Phi4ForCausalLMV for microsoft/Phi-4-reasoning-vision-15B (#38306)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
2026-04-02 21:14:57 -07:00
Bowen Bao
201d2ea5bf [CI][ROCm] Add Qwen3.5-35B-A3B-MXFP4 model eval into CI (#38664)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
2026-04-03 04:05:45 +00:00
Bowen Bao
103f0de565 [ROCm][Quantization][1/N] Refactor quark_moe w_mxfp4 w/ oracle (#38774)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
2026-04-03 03:29:57 +00:00
wliao2
32e0c0bfa2 refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
2026-04-03 11:21:47 +08:00
Itay Etelis
4a06e1246e [Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (#38460)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
2026-04-03 03:13:23 +00:00
Carl Y
3bc2734dd0 [Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
2026-04-03 01:47:04 +00:00
Carl Y
1f5ec2889c [mla] Support fused FP8/NVFP4 output quantization in MLA attention (#35792) (#36205)
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 21:16:11 -04:00
Yan Ma
ee3cf45739 [XPU] Initial support for GDN attention on Qwen3-next/Qwen3.5 (#33657)
Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
2026-04-03 08:59:11 +08:00
Matthew Bonanni
05e68e1f81 [CI] Fix test_nixl_connector (#38838) 2026-04-02 17:52:13 -07:00
Vadim Gimpelson
771913e4a0 [Bugfix] Fix NVFP4+MTP crash: force unquantized mtp.fc for Qwen3.5 (#38832)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
2026-04-03 04:45:57 +04:00
1096125073
71a9125c67 [New Model]: add support for telechat3 (#38510)
Signed-off-by: xiayongqiang <xiayq1@chinatelecom.cn>
Co-authored-by: xiayongqiang <xiayq1@chinatelecom.cn>
2026-04-03 08:26:22 +08:00
Nicolò Lucchesi
66e86f1dbd [Kernel] Mamba support different layout for Conv state (#37416) 2026-04-03 01:50:09 +02:00
Michael
bb39382b2b [Bugfix]: Fix Gemma4ToolParser.__init__() missing tools parameter (#38847)
Signed-off-by: Michael Hospedales <hospedales@me.com>
2026-04-02 14:35:19 -07:00
zhanqiuhu
7b743ba953 [CI] Fix: pass string cache_dtype in test_register_kv_caches (#38836) 2026-04-02 19:42:09 +00:00
Stefano Castagnetta
188defbd0b [CI] Add flashinfer.py to attention test source deps (#38792)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2026-04-02 19:24:29 +00:00
Luciano Martins
08ed2b9688 feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
2026-04-02 11:13:28 -07:00
Yanan Cao
ecd5443dbc Bump helion dependency from 0.3.2 to 0.3.3 (#38062)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 10:59:33 -07:00
Stefano Castagnetta
58262dec6e [Bugfix] Fix test mocks after SM100 restriction in #38730 (#38791)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: Claude <noreply@anthropic.com>
2026-04-02 13:12:58 -04:00
Lucas Wilkinson
cb3935a8fc [FA4] Update flash-attention to latest upstream FA4 (#38690)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2026-04-02 17:02:37 +00:00
Bowen Bao
82a006beeb [CI][ROCm] Add gpt-oss w4a8 in CI (#38292)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
2026-04-03 00:06:01 +08:00
wang.yuqi
a9b4f07ba2 [Frontend] Re-enable running MaxSim on GPU (#38620)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
2026-04-03 00:03:13 +08:00
Koushik Dutta
d9408ffba3 Triton MLA perf fixes (#33529)
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2026-04-02 09:40:01 -04:00
Yusuf Mohammad
16a65e4173 [Bugfix] Enable batch-invariant Triton matmul on all Ampere GPUs (SM 8x) (#38427)
Signed-off-by: yusuf <yusufmohammad@live.com>
Signed-off-by: yusuf <yusuf@deeplearningmachine.mynet>
Signed-off-by: Yusuf Mohammad <79484377+YM2132@users.noreply.github.com>
Signed-off-by: <>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yusuf <yusuf@deeplearningmachine.mynet>
2026-04-02 09:29:58 -04:00
bsliu
c0817e4d39 [Model] Add support for Cheers multimodal model (#38788)
Signed-off-by: bsliu <1187291748@qq.com>
Signed-off-by: 吴炳贤 <wubingxian24@mails.ucas.ac.cn>
2026-04-02 21:01:40 +08:00
Harry Mellor
dfe5e31689 Don't compile vision encoder for Transformers backend (#30518)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2026-04-02 12:42:29 +00:00
JartX
2ce3d0ce36 [Feature] KV cache per-token-head INT8/FP8 quantization (#38378)
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yangyang4991 <yangyang4991@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
2026-04-02 08:13:26 -04:00
Jiangyun Zhu
4eefbf9609 [Perf] fuse kernels in gdn (#37813)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
2026-04-02 11:52:18 +00:00
vllmellm
551b3fb39f [ROCm] Enable VLLM triton FP8 moe for gfx1201, tuned for Qwen3-30B-A3B-FP8 tp=2 and Qwen/Qwen3.5-35B-A3B-FP8 tp=2 (#38086)
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
2026-04-02 08:13:42 +00:00
Li, Jiang
c6f722b93e [CPU] Support gelu act in cpu_fused_moe (#38770)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
2026-04-02 14:14:32 +08:00
Xin Yang
9bd7231106 Revert "[Kernel] Add gpt-oss Router GEMM kernel (#37205)" (#38778)
Signed-off-by: Xin Yang <xyangx@amazon.com>
2026-04-01 22:02:32 -07:00
Yanan Cao
73f48ce559 [Kernel] [Helion] Use warning_once in get_gpu_name to prevent log spam (#38743)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
2026-04-01 21:30:31 -07:00
Gregory Shtrasberg
3aab680e3e [ROCm][Bugfix] Fix ROCm runtime failure due to missing symbol (#38750)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
2026-04-01 21:30:11 -07:00
Sergey Zinchenko
5a2d420c17 [Bugfix] Use dedicated MM processor cache in /tokenize to prevent sender-cache pollution (#38545)
Signed-off-by: Sergey Zinchenko <sergey.zinchenko.rnd@gmail.com>
2026-04-01 21:14:49 -07:00
Benjamin Chislett
5f96f9aff1 [Perf] DSV3.2 Indexer Fused Weights Projection (#38684)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
2026-04-02 03:34:49 +00:00
Luka Govedič
694449050f Fix multiline-format string for python 3.10 (#38739)
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
2026-04-02 03:19:35 +00:00
Nick Hill
6241521dd2 [BugFix] Fix precommit breakage due to conflicting in-flight merges (#38759)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
2026-04-01 15:35:06 -07:00
Kevin H. Luu
1785dc5501 Revert "[Bugfix] Fix Qwen3CoderToolParser anyOf/oneOf type resolution for nullable params (#37831)" (#38751) 2026-04-02 06:34:28 +08:00
Chang Su
54500546ac [Bugfix] Preserve original ImportError in gRPC server entrypoint (#38673)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
2026-04-01 22:16:44 +00:00
Jeffrey Wang
de5e6c44c6 [Feat][Executor] Introduce RayExecutorV2 (#36836)
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
2026-04-01 14:34:29 -07:00
yzong-rh
cb268e4e55 [Refactor] Simplify FutureWrapper in MultiprocExecutor (#38644)
Signed-off-by: Yifan <yzong@redhat.com>
Signed-off-by: Yifan Zong <yzong@redhat.com>
2026-04-01 21:28:26 +00:00
Stefano Castagnetta
6183cae1bd [Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
2026-04-01 12:08:40 -07:00
Monishver
c09ad767cd Feature/silu block quant fusion v1 (#32996)
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
2026-04-01 18:50:43 +00:00
Wentao Ye
c9a9db0e02 [Compile] Fix nvfp4 compile warning (#38573)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2026-04-01 18:28:57 +00:00
Chauncey
cbe7d18096 [Misc] Rename think_start_str/think_end_str to reasoning_start_str/reasoning_end_str (#38242)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2026-04-01 09:56:45 -07:00
Michael Goin
db5d0719e1 [Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664)
Signed-off-by: mgoin <mgoin64@gmail.com>
2026-04-01 09:41:42 -07:00
yzong-rh
dc0428ebb8 [NIXL][BUG] Fix Triton heterogeneous TP (#37940)
Signed-off-by: Yifan <yzong@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
2026-04-01 17:23:15 +02:00
Jesus Talavera
148c2072ec Add ibm-granite/granite-vision-3.3-2b to supported models documentation (#38714)
Signed-off-by: Jesus Talavera <jesus.talavera@ibm.com>
2026-04-01 08:22:25 -07:00
majianhan
2f5c3c1ec0 [Misc] Fix docstring typo: buildin -> builtin (#38722)
Co-authored-by: majianhan <majianhan@kylinos.cn>
2026-04-01 07:39:46 -07:00
Fynn Schmitt-Ulms
fa246d5231 Fix shape comment in extract_hidden_states example (#38723)
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
2026-04-01 07:29:33 -07:00
bnellnm
7cf56a59a2 [MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)
Signed-off-by: Bill Nell <bnell@redhat.com>
2026-04-01 09:44:08 -04:00
Elvir Crnčević
5e30e9b9a9 [Bugfix] Revert "Zero-init MLA attention output buffers to prevent NaN from CUDA graph padding" (#38359)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2026-04-01 09:11:10 -04:00
손세정
582340f273 [Bugfix] Fix Qwen3CoderToolParser anyOf/oneOf type resolution for nullable params (#37831)
Signed-off-by: AAISSJ <maze0717@g.skku.edu>
Signed-off-by: <>
Co-authored-by: 세덩 <saison@sedeong-ui-MacBookAir.local>
2026-04-01 20:22:29 +08:00
yjz
992368522f [KVTransfer] Fix TpKVTopology.is_kv_replicated equality case (#38179)
Signed-off-by: JianDan0212 <zhangyj0212@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
2026-04-01 12:41:49 +02:00
Juan Pérez de Algaba
58ee614221 (security) Enforce frame limit in VideoMediaIO (#38636)
Signed-off-by: jperezde <jperezde@redhat.com>
2026-04-01 10:23:45 +00:00
Harry Mellor
f9f6a9097a Add verified label to trigger pre-commit (#38708)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2026-04-01 02:31:02 -07:00
Zhanda Zhu
c75a313824 [Perf] triton bilinear_pos_embed kernel for ViT (#37948)
Signed-off-by: Zhanda Zhu <zhandazhu@gmail.com>
2026-04-01 01:52:02 -07:00
Lukas Geiger
4f6eed3bd4 [Core] Simplify multimodal masking (#34246)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
2026-04-01 01:18:22 -07:00
Li, Jiang
36d7f19897 [CPU] Support head_size 512 in cpu_attn (#38676)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
2026-04-01 05:42:27 +00:00
Jeffrey Wang
2d725b89c5 [Bugfix] Lazy import diskcache to avoid sqlite3/libstdc++ ImportError at startup (#38649)
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
2026-04-01 05:31:20 +00:00
Augusto Yao
ef53395e2c [bugfix] do not add extra linebreak for score/rerank with chat template (#38617)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2026-04-01 04:50:07 +00:00
Lucas Wilkinson
eb47454987 [Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2026-04-01 00:15:53 -04:00
Matthew Bonanni
116f4be405 [1/N][Cleanup] Standardize on use of is_quantized_kv_cache (#38659)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2026-04-01 04:08:01 +00:00
Wentao Ye
7b01d97a22 [Perf] Optimize mean pooling using chunks and index_add, 5.9% E2E throughput improvement (#38559)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2026-04-01 03:54:58 +00:00
HarshRathva
17b72fd1c8 Fix priority preemption regression test in scheduler (#37051)
Signed-off-by: HarshRathva <harshrathvaai@gmail.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
2026-04-01 06:36:12 +03:00
Samu Tamminen
c49497726b [ROCm][perf] Shuffle KV cache to use paged_attention_common (#32914)
Signed-off-by: Samu Tamminen <stammine@amd.com>
Co-authored-by: Tuukka Sarvi <tuukka.sarvi@amd.com>
2026-04-01 03:30:19 +00:00
Ben Browning
cb0b443274 [Misc] Add 20 regression tests for 11 tool parser bug fixes (#38172)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
2026-04-01 03:00:31 +00:00
Luka Govedič
40bb175027 [vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
2026-03-31 22:15:05 -04:00
Elvir Crnčević
0fab52f0aa Fix NaN from stale FP4 scale padding in create_fp4_scale_tensor (#38148)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2026-03-31 19:14:59 -07:00
Yifan Qiao
91e4521f9f [Feat][v1] Simple yet General CPU KV Cache Offloading (#37160)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
2026-03-31 17:58:37 -07:00
Stig-Arne Grönroos
31a719bcd3 [ROCm][perf] fix Aiter sparse MLA with MTP>1 (#37887)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
Signed-off-by: Stig-Arne Grönroos <sgronroo@amd.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
2026-03-31 19:22:23 -04:00
Vedant V Jhaveri
2e56975657 Generative Scoring (#34539)
Signed-off-by: Vedant Jhaveri <vjhaveri@linkedin.com>
Co-authored-by: Vedant Jhaveri <vjhaveri@linkedin.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2026-03-31 16:02:11 -07:00
Chang Su
36f1dc19ae feat(grpc): add periodic stats logging and servicer log forwarding (#38333)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
2026-03-31 15:50:07 -07:00
Asaf Gardin
3dc01ef352 [Quantization] Consolidate dummy format logic into DummyModelLoader (#38637)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
2026-03-31 22:20:45 +00:00
Yanan Cao
cc671cb110 [Kernel] [Helion] [17/N] Add Helion kernel torch.compile support (#38592)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
2026-03-31 17:06:42 -04:00
Wentao Ye
856589ed9a [Refactor] Remove dead code in kv connector and model runner (#38383)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2026-03-31 17:05:23 -04:00
czhu-cohere
517b769b58 [Perf] Fix DBO overlap: capture DeepEP event before yield (#38451)
Signed-off-by: root <conway.zhu@cohere.com>
2026-03-31 20:38:59 +00:00
yzong-rh
d9b90a07ac [MoE Refactor] Migrate Unquantized to Full Oracle Flow (#36286)
Signed-off-by: Yifan Zong <yzong@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: yzong-rh <yzong@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
2026-03-31 15:43:33 -04:00
Olya Kozlova
598190aac3 [fix] Remove trtllm ragged mla prefills (#36540)
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
2026-03-31 12:30:27 -07:00
Xu Jinyang
b779eb3363 [Model] Sync upstream BT=chunk_size fix for GDN chunk_fwd_kernel_o, simplify warmup to single pass (#38343)
Signed-off-by: AuYang <459461160@qq.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
2026-03-31 23:03:24 +04:00
BadrBasowid
077a9a8e37 [torch.compile] Refactor Attention Quant Fusion Pass and Remove Boilerplate (#37373)
Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
2026-03-31 14:15:50 -04:00
Run Yu
07edd551cc [CI/Build] Resolve a dependency deadlock when installing the test dependencies used in CI (#37766)
Signed-off-by: Run Yu <yurun00@gmail.com>
2026-03-31 18:05:14 +00:00
mikaylagawarecki
7c080dd3c5 [4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
2026-03-31 10:21:13 -07:00
Yi Liu
0dd25a44ea [Quantization][Autoround][XPU] Add W4A16 Support (#37986)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
2026-03-31 16:48:24 +00:00
SandishKumarHN
3896e021a0 [Bugfix] Fix FusedMoE weight loading with padded hidden dimensions (#37010)
Signed-off-by: SandishKumarHN <sandish@fb.com>
2026-03-31 12:22:26 -04:00
zhang-prog
b6e636c12c [Fix] handle PaddleOCR-VL image processor max_pixels across Transformers v4/v5 (#38629)
Signed-off-by: zhangyue66 <zhangyue66@baidu.com>
2026-03-31 15:50:41 +00:00
Jingu Kang
f1ff50c86c [Bugfix] clamp dA_cumsum differences to prevent Inf in Mamba2 SSD kernels (#37501)
Signed-off-by: Jingu Kang <jg.k@navercorp.com>
2026-03-31 17:35:51 +02:00
Matthew Bonanni
757068dc65 [Bugfix][Async] Fix async spec decoding with hybrid models (#38556)
Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com>
2026-03-31 11:08:54 -04:00
Nicolò Lucchesi
7337ff7f03 [Docs] PD with Nixl compat matrix (#38628)
Signed-off-by: NickLucche <nlucches@redhat.com>
2026-03-31 15:01:21 +00:00
Kyle Sayers
5869f69c5f [Online Quant] [QeRL] Minor code cleanup (#38574)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
2026-03-31 14:56:43 +00:00
wliao2
4dfad17ed1 replace cuda_device_count_stateless() to current_platform.device_count() (#37841)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
Signed-off-by: wliao2 <wei.liao@intel.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
2026-03-31 22:32:54 +08:00
wenjun liu
e8057c00bc [CI] Avoid concurrent docker pull in intel XPU CI runners to prevent rate limit issues (#38594)
Signed-off-by: wendyliu235 <wenjun.liu@intel.com>
2026-03-31 22:23:18 +08:00
Nicolò Lucchesi
7430389669 [Bugfix][CI] Skip flaky test_eagle test (#38566)
Signed-off-by: NickLucche <nlucches@redhat.com>
2026-03-31 09:42:37 -04:00
ElizaWszola
202f147cf2 Fix MLA runs when use_inductor_graph_partition=True (#38631)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
2026-03-31 13:37:43 +00:00
Jiangyun Zhu
ea7bfde6e4 [CI] fix LM Eval Qwen3.5 Models (B200) (#38632)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
2026-03-31 13:20:08 +00:00
sihao_li
d71a15041f [XPU]move testing dependencies from Dockerfile to xpu-test.in (#38596)
Signed-off-by: sihao.li <sihao.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
2026-03-31 12:49:43 +00:00
Ilya Markov
abdbb68386 [EPLB] Add alternative communication for EPLB weight exchange (#33176)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
Co-authored-by: Markov Ilya <markovilya19@gmail.com>
2026-03-31 08:17:12 -04:00
liuzhenwei
0c63739135 [EPD] update EPD script arguments (#36742)
Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>
2026-03-31 12:02:09 +00:00
wang.yuqi
719735d6c5 [CI Failure] pin colmodernvbert revision (#38612)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-31 10:54:54 +00:00
Maosheng Liao
aae3e688f8 Fix document of torchrun_example.py (#31113) 2026-03-31 10:54:23 +00:00
Matthew Bonanni
7d65463528 [WIP][CI][Bugfix] Fix test_run_eagle_dp (#38584)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2026-03-31 12:30:25 +02:00
Mateusz Sokół
8278825b57 DOC: TPU mention fix (#38129)
Signed-off-by: Mateusz Sokół <mat646@gmail.com>
2026-03-31 03:27:56 -07:00
Chang Su
acf7292bf2 [Misc] Move --grpc CLI argument into make_arg_parser (#38570)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
2026-03-31 03:24:05 -07:00
Chauncey
ce884756f0 [Feature]: add presence_penalty and frequency_penalty fields to Responses API (#38613)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2026-03-31 08:45:57 +00:00
wang.yuqi
d9d21eb8e3 [Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
2026-03-31 07:52:00 +00:00
Yintong Lu
f09daea261 [CPU] Support int8 compute mode in CPU AWQ (#35697)
Signed-off-by: Yintong Lu <yintong.lu@intel.com>
2026-03-31 15:27:37 +08:00
Kevin H. Luu
42318c840b [ci] Remove benchmarks job (#38611) 2026-03-31 06:46:21 +00:00
zhangyiming
1ac6694297 [OOT] Add OOT support for linear kernel. (#37989)
Signed-off-by: menogrey <1299267905@qq.com>
2026-03-31 14:33:21 +08:00
Kfir Toledo
6cc7abdc66 [kv_offload+HMA] Fix num_blocks with different per-layer page sizes and improve assert message (#38554)
Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
2026-03-31 06:00:40 +00:00
Flora Feng
d53cb9cb8e [Tool Parser][2/3] Use self.tools instead of request.tools in tool parsers (#38189)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
2026-03-31 13:41:36 +08:00
Louie Tsai
44eef0ca1e vLLM Benchmark Suite perf regression after PR#32723 (#38576)
Signed-off-by: louie-tsai <louie.tsai@intel.com>
2026-03-31 05:23:17 +00:00
Andreas Karatzas
b9cdc85207 [ROCm][CI] Fix Whisper translation test attention backend selection (#38508)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
2026-03-31 13:21:49 +08:00
Flora Feng
3e802e8786 [Mypy] Fix adjust_request typing (#38264)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
2026-03-31 04:21:18 +00:00
Martin Hickey
350af48e14 [KVConnector] Remove redundant method KVConnectorOutput::merge() (#38546)
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
2026-03-31 07:11:02 +03:00
Lucas Kabela
e31915063d [Bugfix] Fix for builtins (forward fix of pytorch/177558) (#37234)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
2026-03-31 01:08:11 +00:00
Flora Feng
29e48707e8 [Refactor] Consolidate Tool type alias in tool_parsers/utils.py (#38265)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
2026-03-31 00:55:51 +00:00
sungsoo ha
4ac227222f [Bugfix][DCP] Fix CUDA graph capture for Decode Context Parallelism (#36070)
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-30 20:20:43 -04:00
Vadim Gimpelson
bb51d5b40d Add @vadiklyutiy as committer (#38589)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
2026-03-31 07:50:04 +08:00
Prathmesh Bhatt
93b3ec1585 feat(attention): extract KV-cache update from FlashAttentionDiffKV ba… (#36466)
Signed-off-by: Prathmesh Bhatt <71340361+Prathmesh234@users.noreply.github.com>
2026-03-30 23:16:09 +00:00
Netanel Haber
e812bf70bd Restore non-hf processor path for Nano-Nemotron-VL (bypass call_hf_processor_mm_only) - fixes #38018 (#38567)
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
2026-03-30 21:56:52 +00:00
SandishKumarHN
bcc6f67447 [Bugfix] Use null block (0) for padded block table entries (#35431)
Signed-off-by: SandishKumarHN <sandish@fb.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
2026-03-30 14:02:51 -07:00
Asaf Gardin
1fc69f59bb [Bug fix][Quantization] Fix dummy weight loading (#38478)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
2026-03-30 16:38:02 -04:00
Micah Williamson
d9c7db18da [ROCm][CI] Pin test_hybrid test to TRITON_ATTN on ROCm (#38381)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
2026-03-30 20:26:46 +00:00
Ilya Markov
12701e8af2 [EPLB] Optmize eplb mapping and record in router for prefill (#36261)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
2026-03-30 19:48:33 +00:00
Benjamin Chislett
494636b29d [Feat][Spec Decode] DFlash (#36847)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
2026-03-30 15:03:15 -04:00
mikaylagawarecki
ab1a6a43fa [3/n] Migrate cutlass/scaled_mm_entry.cu torch stable ABI (#37221)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
2026-03-30 11:20:13 -07:00
fangyuchu
b5e608258e [Refactor] Unify engine process monitoring in engine manager and add Ray backend support (#35862)
Signed-off-by: fangyuchu <fangyuchu@qq.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
2026-03-30 10:16:09 -07:00
Matthew Bonanni
2c734ed0e0 [Bugfix][MLA] Change default SM100 MLA prefill backend back to TRT-LLM (#38562)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2026-03-30 09:51:24 -07:00
Chendi.Xue
3b1dbaad4e [HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64,tp=4) (#37467)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
2026-03-30 16:47:30 +00:00
Johnny
b4a2f3ac36 [NVIDIA] Bugfix NVFP4 DGX Spark and RTX50 (#38423)
Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
Signed-off-by: Johnny <johnnynuca14@gmail.com>
2026-03-30 09:36:18 -07:00
roikoren755
8e6293e838 [Mamba] Add stochastic rounding support (#35753)
Signed-off-by: Roi Koren <roik@nvidia.com>
2026-03-30 12:33:49 -04:00
Hongxia Yang
dbdd9ae067 [ROCm][Bugfix] fix exception related to trust_remote_code for MiniMax-M2.1-MXFP4 (#37698)
Signed-off-by: Hongxia Yang <hongxiay.yang@amd.com>
Co-authored-by: Hongxia Yang <hongxiay.yang@amd.com>
2026-03-30 15:49:23 +00:00
Matthias Gehre
e8b055a5ac [Bugfix] Handle ParallelLMHead in compressed-tensors get_quant_method (#37291)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
2026-03-30 07:30:52 -07:00
tomeras91
246dc7d864 [Misc] Add @tomeras91 as a maintainer of Nemotron related code + mamba block (#38547)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
2026-03-30 21:12:17 +08:00
Thomas Parnell
7c3f88b2a8 [Bugfix] Remove false-positive format mismatch warnings in FLA ops (#38255)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
2026-03-30 12:32:26 +00:00
Li, Jiang
6557f4937f [Bugfix][CPU] Skip set_num_threads after thread binding (#38535)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
2026-03-30 20:13:00 +08:00
Andreas Karatzas
677424c7ac [Core][CI] Add opt-in media URL caching via VLLM_MEDIA_CACHE (#37123)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
2026-03-30 04:58:53 -07:00
Collin McCarthy
1031c84c36 Fix ambiguous num_blocks for hybrid attn mamba (#37236)
Signed-off-by: Collin McCarthy <cmccarthy@nvidia.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
2026-03-30 11:09:45 +00:00
aliialsaeedii
7e76af14fa [Bugfix][Frontend] Return 400 for corrupt/truncated image inputs instead of 500 (#38253)
Signed-off-by: aliialsaeedii <ali.al-saeedi@nscale.com>
2026-03-30 10:26:46 +00:00
yzong-rh
3683fe6c06 [Bugfix] Fix shared-object aliasing in n>1 streaming with tool calls (#38158)
Signed-off-by: Yifan Zong <yzong@redhat.com>
Signed-off-by: Yifan <yzong@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
2026-03-30 10:12:13 +00:00
Nicolò Lucchesi
cc06b4e86b [Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping cudagraph sizes (#38270)
Signed-off-by: NickLucche <nlucches@redhat.com>
2026-03-30 09:41:50 +00:00
TJian
03ac6ca895 [ROCm] [DOC] Update the Documentation to include ROCm Nightly Wheel support (#38457)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
2026-03-30 02:25:46 -07:00
haosdent
a08b7733fd [CI] Fix SPLADE pooler test broken by #38139 (#38495)
Signed-off-by: haosdent <haosdent@gmail.com>
2026-03-30 07:48:33 +00:00
Tan Pin Siang
85c0950b1f [ROCm] Enable MORI EP for unquantized MoE with AITER backend (#37529)
Signed-off-by: Tan Pin Siang <pinsiang.tan@amd.com>
2026-03-30 15:19:33 +08:00
Juan Pérez de Algaba
57861ae48d (security) Fix SSRF in batch runner download_bytes_from_url (#38482)
Signed-off-by: jperezde <jperezde@redhat.com>
2026-03-30 07:10:01 +00:00
Jee Jee Li
ac30a8311e [Bugfix][Model] Fix PixtralForConditionalGeneration LoRA (#36963)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
2026-03-29 23:59:42 -07:00
PikaPikachu
63babd17f1 [Model][Quantization] Add GGUF support for MiniMax-M2.1 (#36965)
Signed-off-by: kangletian <Letian.Kang@amd.com>
2026-03-30 14:24:06 +08:00
Kevin H. Luu
fec5aeca12 [ci] Soft fail and disable retry for AMD build image job (#38505)
Signed-off-by: Kevin H. Luu <khluu000@gmail.com>
2026-03-29 23:05:26 -07:00
648 changed files with 37572 additions and 9674 deletions

View File

@@ -5,6 +5,7 @@ steps:
depends_on: []
device: amd_cpu
no_plugin: true
soft_fail: true
commands:
- >
docker build
@@ -20,11 +21,3 @@ steps:
- docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
env:
DOCKER_BUILDKIT: "1"
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 1
- exit_status: -10 # Agent was lost
limit: 1
- exit_status: 1 # Machine occasionally fail
limit: 1

View File

@@ -13,12 +13,14 @@ steps:
- tests/kernels/attention/test_cpu_attn.py
- tests/kernels/moe/test_cpu_fused_moe.py
- tests/kernels/test_onednn.py
- tests/kernels/test_awq_int4_to_int8.py
commands:
- |
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py"
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
- label: CPU-Compatibility Tests
depends_on: []

View File

@@ -36,6 +36,7 @@
"model": "meta-llama/Llama-3.1-8B-Instruct",
"backend": "vllm",
"ignore-eos": "",
"temperature": 0,
"num_prompts": 200
}
},
@@ -127,4 +128,4 @@
}
}
]
}
}

View File

@@ -22,6 +22,7 @@
"hf_split": "test",
"no_stream": "",
"no_oversample": "",
"temperature": 0,
"num_prompts": 200
}
},

View File

@@ -26,6 +26,7 @@
"model": "meta-llama/Llama-3.1-8B-Instruct",
"backend": "vllm",
"ignore-eos": "",
"temperature": 0,
"num_prompts": 200
}
},

View File

@@ -26,6 +26,7 @@
"model": "meta-llama/Llama-3.1-8B-Instruct",
"backend": "vllm",
"ignore-eos": "",
"temperature": 0,
"num_prompts": 200
}
},

View File

@@ -21,6 +21,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -47,6 +48,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -73,6 +75,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -100,6 +103,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -127,6 +131,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -151,6 +156,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
}

View File

@@ -13,6 +13,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -30,6 +31,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -47,6 +49,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -67,6 +70,7 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
}

View File

@@ -1,9 +1,10 @@
#!/bin/bash
set -euox pipefail
export VLLM_CPU_CI_ENV=0
export VLLM_CPU_KVCACHE_SPACE=1 # avoid OOM
echo "--- PP+TP"
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 --max-model-len=4096 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
vllm bench serve \
@@ -23,7 +24,7 @@ if [ "$failed_req" -ne 0 ]; then
fi
echo "--- DP+TP"
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
vllm bench serve \

View File

@@ -239,13 +239,29 @@ fi
# --- Docker housekeeping ---
cleanup_docker
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin "$REGISTRY"
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com
# --- Build or pull test image ---
if [[ -n "${IMAGE_TAG_XPU:-}" ]]; then
echo "Using prebuilt XPU image: ${IMAGE_TAG_XPU}"
docker pull "${IMAGE_TAG_XPU}"
IMAGE="${IMAGE_TAG_XPU:-${image_name}}"
echo "Using image: ${IMAGE}"
if docker image inspect "${IMAGE}" >/dev/null 2>&1; then
echo "Image already exists locally, skipping pull"
else
echo "Using prebuilt XPU image: ${image_name}"
docker pull "${image_name}"
echo "Image not found locally, waiting for lock..."
flock /tmp/docker-pull.lock bash -c "
if docker image inspect '${IMAGE}' >/dev/null 2>&1; then
echo 'Image already pulled by another runner'
else
echo 'Pulling image...'
timeout 900 docker pull '${IMAGE}'
fi
"
echo "Pull step completed"
fi
remove_docker_container() {

View File

@@ -42,6 +42,7 @@ docker run \
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
python3 examples/basic/offline_inference/generate.py --model OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc --block-size 64 --enforce-eager --max-model-len 8192
cd tests
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
pytest -v -s v1/engine

View File

@@ -790,7 +790,7 @@ steps:
- tests/kernels/helion/
- vllm/platforms/rocm.py
commands:
- pip install helion
- pip install helion==0.3.3
- pytest -v -s kernels/helion/

View File

@@ -2,14 +2,6 @@ group: Benchmarks
depends_on:
- image-build
steps:
- label: Benchmarks
timeout_in_minutes: 20
working_dir: "/vllm-workspace/.buildkite"
source_file_dependencies:
- benchmarks/
commands:
- bash scripts/run-benchmarks.sh
- label: Benchmarks CLI Test
timeout_in_minutes: 20
source_file_dependencies:

View File

@@ -72,6 +72,7 @@ steps:
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
- tests/compile/passes/test_fusion_attn.py
- tests/compile/passes/test_mla_attn_quant_fusion.py
- tests/compile/passes/test_silu_mul_quant_fusion.py
- tests/compile/passes/distributed/test_fusion_all_reduce.py
- tests/compile/fullgraph/test_full_graph.py
@@ -79,6 +80,7 @@ steps:
# b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
- nvidia-smi
- pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
- pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_devices=2 is not set
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py

View File

@@ -224,6 +224,20 @@ steps:
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code"
- label: MessageQueue TCP Multi-Node (2 GPUs)
timeout_in_minutes: 10
working_dir: "/vllm-workspace/tests"
num_devices: 1
num_nodes: 2
no_plugin: true
optional: true
source_file_dependencies:
- vllm/distributed/device_communicators/shm_broadcast.py
- vllm/distributed/parallel_state.py
- tests/distributed/test_mq_tcp_multinode.py
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 1 $IMAGE_TAG "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py"
- label: Distributed NixlConnector PD accuracy (4 GPUs)
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
@@ -294,3 +308,23 @@ steps:
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py
- label: RayExecutorV2 (4 GPUs)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:
- vllm/v1/executor/ray_executor_v2.py
- vllm/v1/executor/abstract.py
- vllm/v1/executor/multiproc_executor.py
- tests/distributed/test_ray_v2_executor.py
- tests/distributed/test_ray_v2_executor_e2e.py
- tests/distributed/test_pipeline_parallel.py
- tests/basic_correctness/test_basic_correctness.py
commands:
- export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1
- export NCCL_CUMEM_HOST_ENABLE=0
- pytest -v -s distributed/test_ray_v2_executor.py
- pytest -v -s distributed/test_ray_v2_executor_e2e.py
- pytest -v -s distributed/test_pipeline_parallel.py -k "ray"
- TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray"

View File

@@ -13,8 +13,8 @@ steps:
- pytest -v -s distributed/test_eplb_algo.py
- pytest -v -s distributed/test_eplb_utils.py
- label: EPLB Execution
timeout_in_minutes: 20
- label: EPLB Execution # 17min
timeout_in_minutes: 27
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:

View File

@@ -2,6 +2,16 @@ group: Kernels
depends_on:
- image-build
steps:
- label: vLLM IR Tests
timeout_in_minutes: 10
working_dir: "/vllm-workspace/"
source_file_dependencies:
- vllm/ir
- vllm/kernels
commands:
- pytest -v -s tests/ir
- pytest -v -s tests/kernels/ir
- label: Kernels Core Operation Test
timeout_in_minutes: 75
source_file_dependencies:
@@ -19,6 +29,7 @@ steps:
- vllm/v1/attention
# TODO: remove this dependency (https://github.com/vllm-project/vllm/issues/32267)
- vllm/model_executor/layers/attention
- vllm/utils/flashinfer.py
- tests/kernels/attention
commands:
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
@@ -129,7 +140,7 @@ steps:
- vllm/utils/import_utils.py
- tests/kernels/helion/
commands:
- pip install helion
- pip install helion==0.3.3
- pytest -v -s kernels/helion/

View File

@@ -18,5 +18,6 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
- pytest models/multimodal/generation/test_phi4siglip.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_phi4siglip.py
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'

30
.github/CODEOWNERS vendored
View File

@@ -2,16 +2,20 @@
# for more info about CODEOWNERS file
# This lists cover the "core" components of vLLM that require careful review
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng @vadiklyutiy
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
/vllm/lora @jeejeelee
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516
/vllm/model_executor/layers/mamba @tdoublep @tomeras91
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 @vadiklyutiy
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
@@ -47,9 +51,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/attention @LucasWilkinson @MatthewBonanni
/vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
/vllm/v1/attention/backends/mla @pavanimajety
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety @vadiklyutiy
/vllm/v1/attention/backends/triton_attn.py @tdoublep
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516 @vadiklyutiy
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/vllm/v1/sample @22quinn @houseroad @njhill
/vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni
@@ -71,8 +75,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin
/tests/evals @mgoin @vadiklyutiy
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
@@ -82,7 +87,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee
/tests/models/language/generation/test_hybrid.py @tdoublep
/tests/models/language/generation/test_hybrid.py @tdoublep @tomeras91
/tests/v1/kv_connector/nixl_integration @NickLucche
/tests/v1/kv_connector @ApostaC @orozery
/tests/v1/kv_offload @ApostaC @orozery
@@ -126,9 +131,14 @@ mkdocs.yaml @hmellor
/vllm/platforms/xpu.py @jikunshang
/docker/Dockerfile.xpu @jikunshang
# Nemotron-specific files
/vllm/model_executor/models/*nemotron* @tomeras91
/vllm/transformers_utils/configs/*nemotron* @tomeras91
/tests/**/*nemotron* @tomeras91
# Qwen-specific files
/vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow
/vllm/model_executor/models/qwen* @sighingnow
/vllm/model_executor/models/qwen* @sighingnow @vadiklyutiy
/vllm/transformers_utils/configs/qwen* @sighingnow @vadiklyutiy
# MTP-specific files
/vllm/model_executor/models/deepseek_mtp.py @luccafong
@@ -144,7 +154,7 @@ mkdocs.yaml @hmellor
# Kernels
/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep
/vllm/v1/attention/ops/triton_unified_attention.py @tdoublep
/vllm/model_executor/layers/fla @ZJY0516
/vllm/model_executor/layers/fla @ZJY0516 @vadiklyutiy
# ROCm related: specify owner with write access to notify AMD folks for careful code review
/vllm/**/*rocm* @tjtanaa

View File

@@ -28,6 +28,7 @@ jobs:
});
const hasReadyLabel = pr.labels.some(l => l.name === 'ready');
const hasVerifiedLabel = pr.labels.some(l => l.name === 'verified');
const { data: mergedPRs } = await github.rest.search.issuesAndPullRequests({
q: `repo:${context.repo.owner}/${context.repo.repo} is:pr is:merged author:${pr.user.login}`,
@@ -35,10 +36,10 @@ jobs:
});
const mergedCount = mergedPRs.total_count;
if (hasReadyLabel || mergedCount >= 4) {
core.info(`Check passed: ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
if (hasReadyLabel || hasVerifiedLabel || mergedCount >= 4) {
core.info(`Check passed: verified label=${hasVerifiedLabel}, ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
} else {
core.setFailed(`PR must have the 'ready' label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
core.setFailed(`PR must have the 'verified' or 'ready' (which also triggers tests) label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
}
pre-commit:

View File

@@ -39,7 +39,7 @@ repos:
rev: 0.11.1
hooks:
- id: pip-compile
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
args: [requirements/test.in, -c, requirements/common.txt, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
files: ^requirements/test\.(in|txt)$
- id: pip-compile
alias: pip-compile-rocm

View File

@@ -309,7 +309,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v4.2.1")
set(CUTLASS_REVISION "v4.4.2")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -340,10 +340,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp")
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@@ -490,185 +488,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
@@ -693,55 +512,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MLA_ARCHS)
endif()
# CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
@@ -787,36 +557,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"in CUDA target architectures.")
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
#
# Machete kernels
@@ -887,34 +627,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
@@ -964,7 +676,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp")
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -979,6 +696,299 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}")
endif()
#
# CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
#
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
#
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
#
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
#
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
#
# W4A8 kernels (moved from _C to _C_stable_libtorch)
#
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch
@@ -987,6 +997,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SOURCES ${VLLM_STABLE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
@@ -1000,6 +1011,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA)
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
endif()
#
@@ -1015,7 +1030,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/gpt_oss_router_gemm.cu"
"csrc/moe/router_gemm.cu")
endif()

View File

@@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark: Fused FP8 output quantization in merge_attn_states
Compares fused vs unfused approaches for producing FP8-quantized merged
attention output:
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant
Usage:
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
"""
import argparse
import itertools
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
)
# ---------------------------------------------------------------------------
# Configuration defaults
# ---------------------------------------------------------------------------
NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]
# (label, num_heads, head_size) — num_heads is for TP=1
HEAD_CONFIGS = [
("DeepSeek-V3 MLA", 128, 128),
("Llama-70B", 64, 128),
("Llama-8B", 32, 128),
]
TP_SIZES = [1, 2, 4, 8]
INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
QUANTILES = [0.5, 0.2, 0.8]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def short_dtype(dtype: torch.dtype) -> str:
return str(dtype).removeprefix("torch.")
def make_inputs(
num_tokens: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
"""Create random prefix/suffix outputs and LSEs."""
prefix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
suffix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
# Sprinkle some inf values to exercise edge-case paths
mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
prefix_lse[mask] = float("inf")
mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
suffix_lse[mask2] = float("inf")
return prefix_output, suffix_output, prefix_lse, suffix_lse
def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
"""Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
applying TP division to num_heads and skipping invalid combos."""
configs = []
for (_, nh, hs), nt, dtype, tp in itertools.product(
head_configs, num_tokens_list, input_dtypes, tp_sizes
):
nh_tp = nh // tp
if nh_tp >= 1:
configs.append((nt, nh_tp, hs, short_dtype(dtype)))
return configs
def parse_args():
parser = argparse.ArgumentParser(
description="Benchmark merge_attn_states fused FP8 quantization"
)
parser.add_argument(
"--num-tokens",
type=int,
nargs="+",
default=None,
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
)
parser.add_argument(
"--tp",
type=int,
nargs="+",
default=None,
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
)
parser.add_argument(
"--dtype",
type=str,
nargs="+",
default=None,
help="Input dtypes (e.g. bfloat16 float16 float32). "
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
)
return parser.parse_args()
# ---------------------------------------------------------------------------
# Parse args and build configs before decorators
# ---------------------------------------------------------------------------
args = parse_args()
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
tp_sizes = args.tp if args.tp else TP_SIZES
if args.dtype:
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
else:
input_dtypes = INPUT_DTYPES
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
torch._dynamo.config.recompile_limit = 8888
# ---------------------------------------------------------------------------
# Benchmark function
# ---------------------------------------------------------------------------
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
x_vals=configs,
line_arg="provider",
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
ylabel="us",
plot_name="merge_attn_states FP8 (fused vs unfused)",
args={},
)
)
@default_vllm_config()
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
input_dtype = getattr(torch, dtype_str)
fp8_dtype = current_platform.fp8_dtype()
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
num_tokens, num_heads, head_size, input_dtype
)
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
if provider == "fused_cuda":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_cuda(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "fused_triton":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_triton(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "unfused_cuda":
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_cuda(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
else: # unfused_triton
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_triton(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
device_name = current_platform.get_device_name()
print(f"Device: {device_name}")
print(f"Token counts: {num_tokens_list}")
print(f"TP sizes: {tp_sizes}")
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
benchmark.run(print_data=True)
if __name__ == "__main__":
with torch.inference_mode():
main()

View File

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
import torch
import torch.nn.functional as F
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
dtype: torch.dtype
group_size: int # Changed from list[int] to int
def description(self):
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x DT {self.dtype} "
f"x GS {self.group_size}"
)
def get_bench_params() -> list[bench_params_t]:
"""Test configurations covering common model sizes."""
NUM_TOKENS = [16, 128, 512, 2048]
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
DTYPES = [torch.float16, torch.bfloat16]
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params
# Reference implementations
def unfused_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then per-tensor quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Per-tensor quantize (no group_size used here)
silu_out, _ = ops.scaled_fp8_quant(silu_out)
def unfused_groupwise_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then group-wise quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Group quantize - use group_size directly
silu_out, _ = per_token_group_quant_fp8(
silu_out, group_size=group_size, use_ue8m0=False
)
def fused_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
):
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
out, _ = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=quant_dtype,
is_scale_transposed=False,
)
# Bench functions
def bench_fn(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1
globals = {
"x": x,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(x, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
"""Run benchmarks for all implementations."""
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
scale = 1 / params.hidden_size
x = (
torch.randn(
params.num_tokens,
params.hidden_size * 2,
dtype=params.dtype,
device="cuda",
)
* scale
)
timers = []
# Unfused per-tensor FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# Unfused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)
# Fused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_groupwise_fp8_impl",
)
)
return timers
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def main():
torch.set_default_device("cuda")
bench_params = get_bench_params()
print(f"Running {len(bench_params)} benchmark configurations...")
print(
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
)
print()
timers = []
for bp in tqdm(bench_params):
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
timers.extend(result_timers)
print("\n" + "=" * 80)
print("FINAL COMPARISON - ALL RESULTS")
print("=" * 80)
print_timers(timers)
if __name__ == "__main__":
main()

View File

@@ -1,134 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def get_batch_size_range(max_batch_size):
return [2**x for x in range(14) if 2**x <= max_batch_size]
def get_model_params(config):
if config.architectures[0] in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
):
num_experts = config.n_routed_experts
hidden_size = config.hidden_size
elif config.architectures[0] in ("GptOssForCausalLM",):
num_experts = config.num_local_experts
hidden_size = config.hidden_size
else:
raise ValueError(f"Unsupported architecture: {config.architectures}")
return num_experts, hidden_size
def get_benchmark(model, max_batch_size, trust_remote_code):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=get_batch_size_range(max_batch_size),
x_log=False,
line_arg="provider",
line_vals=[
"torch",
"vllm",
],
line_names=["PyTorch", "vLLM"],
styles=([("blue", "-"), ("red", "-")]),
ylabel="TFLOPs",
plot_name=f"{model} router gemm throughput",
args={},
)
)
def benchmark(batch_size, provider):
config = get_config(model=model, trust_remote_code=trust_remote_code)
num_experts, hidden_size = get_model_params(config)
mat_a = torch.randn(
(batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
mat_b = torch.randn(
(num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
bias = torch.randn(
num_experts, dtype=torch.bfloat16, device="cuda"
).contiguous()
is_hopper_or_blackwell = current_platform.is_device_capability(
90
) or current_platform.is_device_capability_family(100)
allow_dsv3_router_gemm = (
is_hopper_or_blackwell
and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
)
allow_gpt_oss_router_gemm = (
is_hopper_or_blackwell
and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
has_bias = False
if allow_gpt_oss_router_gemm:
has_bias = True
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
def runner():
if has_bias:
F.linear(mat_a, mat_b, bias)
else:
F.linear(mat_a, mat_b)
elif provider == "vllm":
def runner():
if allow_dsv3_router_gemm:
ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
elif allow_gpt_oss_router_gemm:
ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
else:
raise ValueError("Unsupported router gemm")
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
runner, quantiles=quantiles
)
def tflops(t_ms):
flops = 2 * batch_size * hidden_size * num_experts
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
parser.add_argument("--max-batch-size", default=16, type=int)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
# Get the benchmark function
benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
# Run performance benchmark
benchmark.run(print_data=True)

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Benchmarks the fused Triton bilinear position-embedding kernel against
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
#
# == Usage Examples ==
#
# Default benchmark:
# python3 benchmark_vit_bilinear_pos_embed.py
#
# Custom parameters:
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
import itertools
import torch
from vllm.model_executor.models.qwen3_vl import (
pos_embed_interpolate_native,
triton_pos_embed_interpolate,
)
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# (h, w) configurations to benchmark
h_w_configs = [
(16, 16),
(32, 32),
(48, 48),
(64, 64),
(128, 128),
(32, 48),
(60, 80),
]
# Temporal dimensions
t_range = [1]
configs = list(itertools.product(t_range, h_w_configs))
def get_benchmark(
num_grid_per_side: int,
spatial_merge_size: int,
hidden_dim: int,
dtype: torch.dtype,
device: str,
):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["t", "h_w"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["native", "triton"],
line_names=["Native (PyTorch)", "Triton"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name=(
f"vit-bilinear-pos-embed-"
f"grid{num_grid_per_side}-"
f"dim{hidden_dim}-"
f"{dtype}"
),
args={},
)
)
def benchmark(t, h_w, provider):
h, w = h_w
torch.manual_seed(42)
embed_weight = (
torch.randn(
num_grid_per_side * num_grid_per_side,
hidden_dim,
device=device,
dtype=dtype,
)
* 0.25
)
quantiles = [0.5, 0.2, 0.8]
if provider == "native":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: pos_embed_interpolate_native(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
else:
assert HAS_TRITON, "Triton not available"
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_pos_embed_interpolate(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark bilinear position embedding interpolation."
)
parser.add_argument(
"--num-grid-per-side",
type=int,
default=48,
help="Position embedding grid size (default: 48 for Qwen3-VL)",
)
parser.add_argument(
"--spatial-merge-size",
type=int,
default=2,
help="Spatial merge size (default: 2)",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=1152,
help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
)
parser.add_argument(
"--device",
type=str,
choices=["cuda:0", "cuda:1"],
default="cuda:0",
)
parser.add_argument(
"--save-path",
type=str,
default="./vit_pos_embed/",
)
args = parser.parse_args()
dtype = torch.bfloat16
bench = get_benchmark(
args.num_grid_per_side,
args.spatial_merge_size,
args.hidden_dim,
dtype,
args.device,
)
bench.run(print_data=True, save_path=args.save_path)

View File

@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp")

View File

@@ -39,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 29210221863736a08f71a866459e368ad1ac4a95
GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@@ -3,22 +3,33 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <limits>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/w8a8/fp8/common.cuh"
#include "../dispatch_utils.h"
namespace vllm {
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, const uint NUM_THREADS>
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
bool USE_FP8_OUTPUT>
__global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
output_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride) {
using pack_128b_t = uint4;
const uint output_head_stride, const uint prefix_num_tokens,
const float* output_scale) {
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
// Outputs store pack_size elements of output_t, which is smaller for FP8.
using input_pack_t = uint4;
using output_pack_t =
std::conditional_t<USE_FP8_OUTPUT,
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
uint4>;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -41,8 +52,45 @@ __global__ void merge_attn_states_kernel(
head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
output_t* output_head_ptr = output + dst_head_offset;
// Pre-invert scale: multiplication is faster than division
float fp8_scale_inv = 1.0f;
if constexpr (USE_FP8_OUTPUT) {
fp8_scale_inv = 1.0f / *output_scale;
}
// If token_idx >= prefix_num_tokens, just copy from suffix
if (token_idx >= prefix_num_tokens) {
if (pack_offset < head_size) {
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
}
}
if (output_lse != nullptr && pack_idx == 0) {
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
output_lse[head_idx * num_tokens + token_idx] = s_lse;
}
return;
}
// For tokens within prefix range, merge prefix and suffix
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
@@ -53,20 +101,34 @@ __global__ void merge_attn_states_kernel(
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that requests position is -inf, and at
prefix hit, every LSE entry for that request's position is -inf, and at
this moment we merge cross-attention at first. For now we simply emit
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
this problem.
*/
if (std::isinf(max_lse)) {
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
prefix_head_ptr)[pack_offset / pack_size];
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
p_out_pack;
if constexpr (USE_FP8_OUTPUT) {
// Convert prefix values to FP8 (since -inf means no data,
// prefix_output is expected to be zeros)
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
}
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -84,30 +146,43 @@ __global__ void merge_attn_states_kernel(
const float s_scale = s_se / out_se;
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
// Compute merged values in float32
float o_out_f[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
o_out_pack;
// Convert and store
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
o_out_f[i], fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
output_pack_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f[i]);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
}
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -134,50 +209,73 @@ __global__ void merge_attn_states_kernel(
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT) \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride); \
num_heads, head_size, prefix_head_stride, output_head_stride, \
prefix_num_tokens, output_scale_ptr); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
* @param prefill_tokens_with_context Number of prefill tokens with context
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
* is computed by merging prefix_output and suffix_output. For remaining tokens
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
* from suffix_output.
* @param output_scale Optional scalar tensor for FP8 static quantization.
* When provided, output must be FP8 dtype.
*/
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
void merge_attn_states_launcher(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1);
// Thread mapping is based on input BF16 pack_size
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
const uint prefix_num_tokens =
prefill_tokens_with_context.has_value()
? static_cast<uint>(prefill_tokens_with_context.value())
: num_tokens;
TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
float* output_scale_ptr = nullptr;
if (output_scale.has_value()) {
output_scale_ptr = output_scale.value().data_ptr<float>();
}
// Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8.
const uint threads_per_head = head_size / pack_size;
@@ -189,14 +287,22 @@ void merge_attn_states_launcher(torch::Tensor& output,
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
if (output_scale.has_value()) {
// FP8 output path - dispatch on output FP8 type
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
});
} else {
// Original BF16/FP16/FP32 output path
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
}
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>( \
output, output_lse, prefix_output, prefix_lse, suffix_output, \
suffix_lse, prefill_tokens_with_context, output_scale); \
}
void merge_attn_states(torch::Tensor& output,
@@ -204,6 +310,21 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
if (output_scale.has_value()) {
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
"output must be FP8 when output_scale is provided, got: ",
output.scalar_type());
} else {
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
"output dtype (", output.scalar_type(),
") must match prefix_output dtype (",
prefix_output.scalar_type(), ") when output_scale is not set");
}
// Always dispatch on prefix_output (input) dtype
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
CALL_MERGE_ATTN_STATES_LAUNCHER);
}

View File

@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,

View File

@@ -24,6 +24,8 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include <cuda.h>
#endif
#if defined(__gfx942__)
@@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
const int64_t n = src_ptrs.size(0);
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
if (n == 0) return;
const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
const int64_t* size_data = sizes.data_ptr<int64_t>();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
// driver call, amortizing per-copy submission overhead.
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
// so we reinterpret_cast the tensor data directly to avoid copies.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
size_t attrs_idx = 0;
size_t fail_idx = 0;
CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
#else
// Fallback for CUDA < 12.8 and ROCm: individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
stream);
}
#endif
}
namespace vllm {
// Grid: (num_layers, num_pairs)

View File

@@ -30,13 +30,15 @@
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
return FusedMOEAct::SiluAndMul;
} else if (act == "swigluoai") {
return FusedMOEAct::SwigluOAIAndMul;
} else if (act == "gelu") {
return FusedMOEAct::GeluAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
@@ -104,6 +106,43 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
}
}
template <typename scalar_t>
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride, const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
vec_op::FP32Vec16 w2_vec(0.5);
alignas(64) float temp[16];
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto er_input_vec = gate_vec * w1_vec;
er_input_vec.save(temp);
for (int32_t i = 0; i < 16; ++i) {
temp[i] = std::erf(temp[i]);
}
vec_op::FP32Vec16 er_vec(temp);
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
auto gated_output_fp32 = up_vec * gelu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
@@ -118,6 +157,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
case FusedMOEAct::SiluAndMul:
silu_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::GeluAndMul:
gelu_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}

View File

@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
import os
# Head dimensions divisible by 32 (support all ISAs)
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
# Head dimensions divisible by 16 but not 32 (VEC16 only)
HEAD_DIMS_16 = [80, 112]

View File

@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
#endif
}
inline int get_thread_num() {
#if defined(_OPENMP)
return omp_get_thread_num();
#else
return 0;
#endif
}
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
int inline adjust_num_threads(int m) {

View File

@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
template <typename T> inline bool can_use_brgemm(int M);
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
// pack weight to vnni format
inline int64_t get_4bit_block_k_size(int64_t group_size) {
return group_size > 128 ? 128 : group_size;
}
// pack weight into vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// pack weight to vnni format for int4 (adapted from sglang)
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
@@ -233,6 +241,31 @@ void tinygemm_kernel(
int64_t strideBs,
bool brg);
// int4 scaled GEMM (adapted from sglang)
at::Tensor int4_scaled_mm_cpu(
at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
// int4 tinygemm kernel interface(adapted from sglang)
template <typename scalar_t>
void tinygemm_kernel(
scalar_t* C,
float* C_temp,
const uint8_t* A,
const float* scales_a,
const int32_t* qzeros_a,
const uint8_t* B,
const float* scales_b,
const int8_t* qzeros_b,
const int32_t* compensation,
int8_t* dqB_tmp,
int64_t M,
int64_t K,
int64_t lda,
int64_t ldc_f,
int64_t ldc_s,
bool store_out,
bool use_brgemm);
// TODO: debug print, remove me later
inline void print_16x32i(const __m512i x) {
int32_t a[16];

View File

@@ -0,0 +1,755 @@
// SPDX-License-Identifier: Apache-2.0
// Adapted from sgl-project/sglang
// https://github.com/sgl-project/sglang/pull/8226
#include <ATen/ATen.h>
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
#define BLOCK_N block_size_n()
#define BLOCK_M 128
template <bool sym_quant_act>
struct ActDtype;
template <>
struct ActDtype<true> {
using type = int8_t;
};
template <>
struct ActDtype<false> {
using type = uint8_t;
};
struct alignas(32) m256i_wrapper {
__m256i data;
};
#if defined(CPU_CAPABILITY_AVX512)
inline std::array<m256i_wrapper, 2> load_zps_4vnni(
const int8_t* __restrict__ zps) {
__m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps));
__m256i vzps_high =
_mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps + 8));
__m256i shuffle_mask =
_mm256_set_epi8(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3,
3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask);
vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask);
m256i_wrapper vzps_low_wp, vzps_high_wp;
vzps_low_wp.data = vzps_low;
vzps_high_wp.data = vzps_high;
return {vzps_low_wp, vzps_high_wp};
}
inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
const uint8_t* __restrict__ qB) {
__m256i packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
const __m256i low_mask = _mm256_set1_epi8(0x0f);
__m256i high = _mm256_srli_epi16(packed, 4);
high = _mm256_and_si256(high, low_mask);
__m256i low = _mm256_and_si256(packed, low_mask);
m256i_wrapper low_wp, high_wp;
low_wp.data = low;
high_wp.data = high;
return {low_wp, high_wp};
}
template <int N, int ldb>
void _dequant_weight_zp_only(const uint8_t* __restrict__ B, int8_t* dqB,
const int8_t* __restrict__ qzeros, int64_t K) {
#pragma GCC unroll 2
for (int n = 0; n < N; n += 16) {
auto [zps_low_wp, zps_high_wp] = load_zps_4vnni(&qzeros[n]);
auto zps_low = zps_low_wp.data;
auto zps_high = zps_high_wp.data;
for (int k = 0; k < K; k += 4) {
auto [vb_low_wp, vb_high_wp] =
load_uint4_as_int8(B + ldb * k + n / 2 * 4);
auto vb_low = vb_low_wp.data;
auto vb_high = vb_high_wp.data;
vb_high = _mm256_sub_epi8(vb_high, zps_high);
vb_low = _mm256_sub_epi8(vb_low, zps_low);
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4),
vb_low);
_mm256_storeu_si256(
reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
}
}
}
template <bool sym_quant_act, int N, bool accum>
void _dequant_and_store(float* __restrict__ output,
const int32_t* __restrict__ input,
const float* __restrict__ scale_a,
const int32_t* __restrict__ zp_a,
const float* __restrict__ scale_b,
const int32_t* __restrict__ comp_b, int M, int ldi,
int ldo, int ldsa = 1) {
for (int m = 0; m < M; ++m) {
float a_scale = *(scale_a + m * ldsa);
__m512 va_scale = _mm512_set1_ps(a_scale);
int32_t a_zp;
__m512i va_zp;
if constexpr (!sym_quant_act) {
a_zp = *(zp_a + m * ldsa);
va_zp = _mm512_set1_epi32(a_zp);
}
int n = 0;
#pragma GCC unroll 2
for (; n < N; n += 16) {
__m512i vc = _mm512_loadu_si512(input + m * ldi + n);
if constexpr (!sym_quant_act) {
__m512i vb_comp = _mm512_loadu_si512(comp_b + n);
vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp));
}
__m512 vc_f = _mm512_cvtepi32_ps(vc);
__m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
__m512 vb_s = _mm512_loadu_ps(scale_b + n);
vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
if constexpr (accum) {
__m512 vo = _mm512_loadu_ps(output + m * ldo + n);
_mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul));
} else {
_mm512_storeu_ps(output + m * ldo + n, vc_f_mul);
}
}
for (; n < N; ++n) {
float dq_val;
if constexpr (sym_quant_act) {
dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n];
} else {
dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale *
scale_b[n];
}
if constexpr (accum) {
output[m * ldo + n] += dq_val;
} else {
output[m * ldo + n] = dq_val;
}
}
}
}
#else
template <int N, int ldb>
void _dequant_weight_zp_only(const uint8_t* B, int8_t* dqB,
const int8_t* qzeros, int64_t K) {
for (int k = 0; k < K; ++k) {
for (int n = 0; n < N / 2; ++n) {
int32_t b = (int32_t)B[k * ldb + n];
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
}
}
}
#endif
#if defined(CPU_CAPABILITY_AVX512)
inline __m512i combine_m256i(__m256i a, __m256i b) {
__m512i c = _mm512_castsi256_si512(a);
return _mm512_inserti64x4(c, b, 1);
}
inline __m512i combine_m256i(std::array<m256i_wrapper, 2> two_256) {
return combine_m256i(two_256[0].data, two_256[1].data);
}
static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
__m512i zero = _mm512_setzero_si512();
__mmask64 blt0 = _mm512_movepi8_mask(b);
return _mm512_mask_sub_epi8(a, blt0, zero, a);
}
template <bool sym_quant_act, int M, int N, int ldb>
void _dequant_gemm_accum_small_M(float* __restrict__ C, const uint8_t* A,
const float* scales_a, const int32_t* qzeros_a,
const uint8_t* B, const float* scales_b,
const int8_t* qzeros_b, int64_t K, int64_t lda,
int64_t ldc) {
constexpr int COLS = N / 16;
__m512i ones = _mm512_set1_epi8(1);
__m512i va;
__m512i vb[COLS];
__m512i vc[M * COLS];
__m512 vscales[COLS];
__m512i vzps[COLS];
__m512i vcompensate[COLS];
Unroll<COLS>{}([&](auto i) {
vscales[i] = _mm512_loadu_ps(scales_b + i * 16);
vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16));
if constexpr (!sym_quant_act) {
vcompensate[i] = _mm512_setzero_epi32();
}
});
Unroll<M * COLS>{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });
auto compute = [&](auto i, int k) {
constexpr const int row = i / COLS;
constexpr const int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
}
if constexpr (row == 0) {
int B_offset = k * ldb + col * 16 * 2;
vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset));
vb[col] = _mm512_sub_epi8(vb[col], vzps[col]);
if constexpr (!sym_quant_act) {
vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]);
}
_mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0);
}
if constexpr (sym_quant_act) {
auto vsb = _mm512_sign_epi8(vb[col], va);
auto vabsa = _mm512_sign_epi8(va, va);
vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb);
} else {
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
}
};
constexpr const int unroll = 4;
int k = 0;
for (; k < K / 4 / unroll; k++) {
Unroll<unroll>{}(
[&](auto i) { Unroll<M * COLS>{}(compute, 4 * (k * unroll + i)); });
}
k *= 4 * unroll;
for (; k < K; k += 4) {
Unroll<M * COLS>{}(compute, k);
}
auto store = [&](auto i) {
constexpr const int row = i / COLS;
constexpr const int col = i % COLS;
__m512 vc_float;
if constexpr (!sym_quant_act) {
vc[i] = _mm512_sub_epi32(
vc[i], _mm512_mullo_epi32(vcompensate[col],
_mm512_set1_epi32(*(qzeros_a + row))));
}
vc_float = _mm512_cvtepi32_ps(vc[i]);
vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row)));
vc_float = _mm512_mul_ps(vc_float, vscales[col]);
auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16);
vc_float = _mm512_add_ps(vc_float, vc_old);
_mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
};
Unroll<M * COLS>{}(store);
}
#define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M) \
_dequant_gemm_accum_small_M<sym_quant_act, M, N, ldb>( \
C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
#endif
template <bool sym_quant_act, int N, int ldb>
void _dequant_gemm_accum(float* C, const uint8_t* A, const float* scales_a,
const int32_t* qzeros_a, const uint8_t* B,
const float* scales_b, const int8_t* qzeros_b,
const int32_t* compensation, int8_t* dqB, int64_t M,
int64_t K, int64_t lda, int64_t ldc, bool use_brgemm) {
#if defined(CPU_CAPABILITY_AVX512)
if (!use_brgemm) {
switch (M) {
case 1:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(1);
break;
case 2:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(2);
break;
case 3:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(3);
break;
case 4:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(4);
break;
default:
TORCH_CHECK(false, "tinygemm_kernel: unexpected M for AVX path!");
}
return;
}
_dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
using Tin = typename ActDtype<sym_quant_act>::type;
Tin* A_ptr = (Tin*)A;
if (use_brgemm) {
int32_t C_i32[M * N];
at::native::cpublas::brgemm(M, N, K, lda, N /*ldb*/, N /*ldc*/,
false /* add_C */, A_ptr, dqB, C_i32,
true /* is_vnni */);
_mm_prefetch(B + N * K / 2, _MM_HINT_T0);
_mm_prefetch(A + K, _MM_HINT_T0);
_dequant_and_store<sym_quant_act, N, true>(C, C_i32, scales_a, qzeros_a,
scales_b, compensation, M,
N /*ldi*/, ldc, 1 /*ldsa*/);
} else
#endif
{
TORCH_CHECK(false, "tinygemm_kernel: scalar path not implemented!");
}
}
template <int N>
inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) {
if (bias_ptr) {
for (int i = 0; i < m; ++i) {
int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 bias_vec = _mm512_loadu_ps(bias_ptr + j);
_mm512_storeu_ps(y_buf + i * N + j, bias_vec);
}
#endif
for (; j < N; ++j) {
y_buf[i * N + j] = bias_ptr[j];
}
}
} else {
for (int i = 0; i < m; ++i) {
int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 zero_vec = _mm512_setzero_ps();
_mm512_storeu_ps(y_buf + i * N + j, zero_vec);
}
#endif
for (; j < N; ++j) {
y_buf[i * N + j] = 0;
}
}
}
}
template <int N, typename out_dtype>
inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m,
int64_t lda) {
for (int i = 0; i < m; ++i) {
int j = 0;
if constexpr (std::is_same<out_dtype, float>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
_mm512_storeu_ps(c_ptr + i * lda + j, y_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = y_buf[i * N + j];
}
} else if constexpr (std::is_same<out_dtype, at::BFloat16>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
__m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
y_bf16_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]);
}
} else if constexpr (std::is_same<out_dtype, at::Half>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
__m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
y_fp16_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]);
}
} else {
TORCH_CHECK(false, "Unsupported output dtype");
}
}
}
void fill_val_stub(int32_t* __restrict__ output, int32_t value, int64_t size) {
using iVec = at::vec::Vectorized<int32_t>;
constexpr int VecSize = iVec::size();
const iVec fill_val_vec = iVec(value);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - VecSize; d += VecSize) {
fill_val_vec.store(output + d);
}
for (; d < size; ++d) {
output[d] = value;
}
}
template <bool sym_quant_act, typename act_dtype, typename out_dtype>
void _da8w4_linear_impl(
act_dtype* __restrict__ input, const float* __restrict__ input_scales,
const int32_t* __restrict__ input_qzeros,
const uint8_t* __restrict__ weight, const float* __restrict__ weight_scales,
const int8_t* __restrict__ weight_qzeros, const float* __restrict__ bias,
out_dtype* __restrict__ output, float* __restrict__ output_temp,
int8_t* __restrict__ dequant_weight_temp, int64_t M, int64_t N, int64_t K,
int64_t num_groups) {
const bool use_brgemm = can_use_brgemm<act_dtype>(M);
int64_t block_m = [&]() -> long {
if (M <= 48) {
return M;
} else if (M < 64) {
return 32;
} else if (M < 96) {
return 64;
} else {
return 128;
}
}();
int64_t Mc = div_up(M, block_m);
bool parallel_on_M = M > 128;
int64_t Nc = N / BLOCK_N;
int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc;
int64_t group_size = div_up(K, num_groups);
int64_t _block_k = get_4bit_block_k_size(group_size);
int64_t Kc = K / _block_k;
int64_t block_per_group = group_size / _block_k;
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
int tid = get_thread_num();
float* C_tmp = output_temp + tid * block_m * BLOCK_N;
int8_t* dqB_tmp = dequant_weight_temp + tid * _block_k * BLOCK_N;
for (const auto i : c10::irange(begin, end)) {
int64_t mc = parallel_on_M ? i / Nc : 0;
int64_t nc = parallel_on_M ? i % Nc : i;
int64_t mc_end = parallel_on_M ? mc + 1 : Mc;
for (int mci = mc; mci < mc_end; ++mci) {
int64_t m_size =
mci * block_m + block_m > M ? M - mci * block_m : block_m;
auto bias_data = bias ? bias + nc * BLOCK_N : nullptr;
copy_bias<BLOCK_N>(bias_data, C_tmp, m_size);
for (int kci = 0; kci < Kc; ++kci) {
int32_t* compensation_ptr =
sym_quant_act
? nullptr
: (int32_t*)(void*)(weight +
(nc * Kc + kci) *
(BLOCK_N *
(_block_k / 2 + sizeof(int32_t))) +
_block_k * BLOCK_N / 2);
_dequant_gemm_accum<sym_quant_act, BLOCK_N, BLOCK_N / 2>(
/*C*/ C_tmp,
/*A*/ (uint8_t*)input + mci * block_m * K + kci * _block_k,
/*scales_a*/ input_scales + mci * block_m,
/*qzeros_a*/ input_qzeros + mci * block_m,
/*B*/ weight + (nc * Kc + kci) *
(BLOCK_N * (_block_k / 2 + sizeof(int32_t))),
/*scales_b*/ weight_scales + nc * BLOCK_N * num_groups +
kci / block_per_group * BLOCK_N,
/*qzeros_b*/ weight_qzeros + nc * BLOCK_N * num_groups +
kci / block_per_group * BLOCK_N,
/*Bcomp*/ compensation_ptr,
/*dqB_tmp*/ dqB_tmp,
/*M*/ m_size,
/*K*/ _block_k,
/*lda*/ K,
/*ldc*/ BLOCK_N,
/*use_brgemm*/ use_brgemm);
}
store_out<BLOCK_N>(C_tmp, output + mci * block_m * N + nc * BLOCK_N,
m_size, N /*lda*/);
}
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
} // anonymous namespace
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_int4_weight_packed_with_compensation(const at::Tensor& weight,
const at::Tensor& scales,
const at::Tensor& qzeros) {
TORCH_CHECK(weight.dim() == 2,
"DA8W4 CPU: Weight should be a 2D tensor for packing");
TORCH_CHECK(
weight.size(1) % 2 == 0,
"DA8W4 CPU: Weight should have even number of columns for packing");
auto new_scales = scales;
auto new_qzeros = qzeros;
if (new_scales.dim() == 1) {
new_scales.unsqueeze_(1);
}
new_scales = new_scales.to(at::kFloat);
if (new_qzeros.dim() == 1) {
new_qzeros.unsqueeze_(1);
}
new_qzeros = new_qzeros.to(at::kChar);
int64_t N = weight.size(0);
int64_t K = weight.size(1);
int64_t G = scales.size(1);
int64_t group_size = K / G;
int64_t _block_k = get_4bit_block_k_size(group_size);
constexpr int block_n = block_size_n();
int64_t Nc = N / block_n;
int64_t Kc = K / _block_k;
auto weight_view = weight.view({Nc, block_n, Kc, _block_k});
at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous();
at::Tensor blocked_weight;
at::Tensor blocked_scales =
new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
at::Tensor blocked_qzeros =
new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) -
new_qzeros.view({Nc, block_n, G, -1});
weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, _block_k});
at::Tensor compensation = weight_sub_qzero.sum(-1);
compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt);
int64_t buffer_size_nbytes =
_block_k * block_n / 2 + block_n * sizeof(int32_t);
blocked_weight = at::empty({Nc, Kc, buffer_size_nbytes}, weight.options());
auto weight_ptr = weight_reordered.data_ptr<uint8_t>();
auto compensation_ptr = compensation.data_ptr<int32_t>();
auto blocked_weight_ptr = blocked_weight.data_ptr<uint8_t>();
int64_t num_blocks = Nc * Kc;
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
auto in_ptr = weight_ptr + i * _block_k * block_n;
auto out_ptr =
blocked_weight_ptr + i * block_n * (_block_k / 2 + sizeof(int32_t));
int32_t* comp_in_prt = compensation_ptr + i * block_n;
int32_t* comp_out_prt =
(int32_t*)(void*)(blocked_weight_ptr +
i * block_n * (_block_k / 2 + sizeof(int32_t)) +
_block_k * block_n / 2);
constexpr int n_group_size = 8;
constexpr int vnni_size = 4;
constexpr int n_group = block_n / n_group_size;
for (int nb = 0; nb < n_group; nb += 2) {
for (int k = 0; k < _block_k; k += vnni_size) {
for (int ni = 0; ni < n_group_size; ++ni) {
for (int ki = 0; ki < vnni_size; ++ki) {
int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n;
int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n;
int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size +
k * block_n / 2 + ki;
uint8_t src_1 = *(in_ptr + src_idx_1);
uint8_t src_2 = *(in_ptr + src_idx_2);
uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4);
*(out_ptr + dst_idx) = dst;
}
}
}
}
for (int nb = 0; nb < block_n; nb++) {
*(comp_out_prt + nb) = *(comp_in_prt + nb);
}
}
});
return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales),
std::move(blocked_qzeros));
}
std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(at::Tensor qweight,
at::Tensor qzeros) {
auto bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;
auto qweight_unsq = qweight.unsqueeze(-1);
auto unpacked = at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF;
auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte);
auto qzeros_unsq = qzeros.unsqueeze(-1);
auto qzeros_unpacked = at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF;
auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte);
return std::make_tuple(qweight_final, qzeros_final);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales) {
auto res = autoawq_to_int4pack(qweight, qzeros);
auto _qweight = std::get<0>(res);
auto _qzeros = std::get<1>(res);
auto _scales = scales;
_qzeros = _qzeros.transpose(-2, -1).contiguous();
_scales = _scales.transpose(-2, -1).contiguous();
if (_qweight.dim() == 3) {
int64_t E = _qweight.size(0);
int64_t K = _qweight.size(2);
int64_t G = _scales.size(2);
int64_t group_size = K / G;
int64_t _block_k = get_4bit_block_k_size(group_size);
int64_t block_n = block_size_n();
int64_t Nc = _qweight.size(1) / block_n;
int64_t Kc = K / _block_k;
int64_t buffer_size_nbytes =
_block_k * block_n / 2 + block_n * sizeof(int32_t);
auto blocked_weight =
at::empty({E, Nc, Kc, buffer_size_nbytes}, _qweight.options());
auto blocked_scales =
at::empty({E, Nc, G, block_n}, _scales.options()).to(at::kFloat);
auto blocked_qzeros =
at::empty({E, Nc, G, block_n}, _qzeros.options()).to(at::kChar);
for (int i = 0; i < _qweight.size(0); i++) {
auto res_ = convert_int4_weight_packed_with_compensation(
_qweight[i], _scales[i], _qzeros[i]);
blocked_weight[i] = std::get<0>(res_);
blocked_scales[i] = std::get<1>(res_);
blocked_qzeros[i] = std::get<2>(res_);
}
_qweight = blocked_weight;
_scales = blocked_scales;
_qzeros = blocked_qzeros;
} else {
auto res_ = convert_int4_weight_packed_with_compensation(_qweight, _scales,
_qzeros);
_qweight = std::get<0>(res_);
_scales = std::get<1>(res_);
_qzeros = std::get<2>(res_);
}
return std::make_tuple(_qweight, _qzeros, _scales);
}
at::Tensor int4_scaled_mm_cpu_with_quant(const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& weight_scales,
const at::Tensor& weight_qzeros,
const std::optional<at::Tensor>& bias,
at::ScalarType output_dtype) {
RECORD_FUNCTION("vllm::int4_scaled_mm_cpu_with_quant",
std::vector<c10::IValue>({input, weight}));
int64_t M_a = input.size(0);
int64_t K_a = input.size(1);
int64_t lda = input.stride(0);
const auto st = input.scalar_type();
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf,
"int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half.");
constexpr bool sym_quant_act = false;
using Tin = typename ActDtype<sym_quant_act>::type;
int64_t act_buffer_size =
M_a * K_a + M_a * sizeof(float) + M_a * sizeof(int32_t);
auto act_buffer =
at::empty({act_buffer_size}, input.options().dtype(at::kByte));
auto Aq_data = act_buffer.data_ptr<uint8_t>();
auto As_data = reinterpret_cast<float*>(Aq_data + M_a * K_a);
auto Azp_data = reinterpret_cast<int32_t*>(As_data + M_a);
fill_val_stub(Azp_data, 128, M_a);
auto out_sizes = input.sizes().vec();
int64_t N = weight_scales.size(0) * weight_scales.size(-1);
out_sizes.back() = N;
auto output = at::empty(out_sizes, input.options());
int64_t Nc = weight.size(0);
int64_t Kc = weight.size(1);
int64_t _block_k = K_a / Kc;
TORCH_CHECK(N == Nc * BLOCK_N, "DA8W4: weight and input shapes mismatch");
int64_t num_groups = weight_scales.size(1);
const uint8_t* b_ptr = weight.data_ptr<uint8_t>();
const float* b_scales_ptr = weight_scales.data_ptr<float>();
const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr<int8_t>();
const float* bias_ptr =
bias.has_value() ? bias.value().data_ptr<float>() : nullptr;
int num_threads = at::get_num_threads();
int64_t temp_buffer_size = num_threads * BLOCK_M * BLOCK_N * sizeof(float) +
num_threads * _block_k * BLOCK_N;
auto c_temp_buffer =
at::empty({temp_buffer_size}, input.options().dtype(at::kChar));
float* c_temp_ptr = (float*)((void*)(c_temp_buffer.data_ptr<int8_t>()));
int8_t* dqB_temp_ptr =
(int8_t*)((void*)(c_temp_ptr + num_threads * BLOCK_M * BLOCK_N));
#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act) \
AT_DISPATCH_FLOATING_TYPES_AND2( \
at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, \
"int4_scaled_mm_cpu", [&] { \
const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>(); \
scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>(); \
at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) { \
for (int64_t m = begin; m < end; ++m) { \
quantize_row_int8<scalar_t>(Aq_data + m * K_a, As_data[m], \
A_data + m * lda, K_a); \
} \
}); \
_da8w4_linear_impl<sym_quant_act, Tin, scalar_t>( \
Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr, \
bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a, \
num_groups); \
});
LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act);
return output;
}
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out,
const float* __restrict__ input, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += Vec::size()) {
fVec x0 = fVec::loadu(input + d);
fVec x1 = fVec::loadu(input + d + fVec::size());
Vec res = convert_from_float_ext<scalar_t>(x0, x1);
res.store(out + d);
}
}
} // anonymous namespace
template <typename scalar_t>
void tinygemm_kernel(scalar_t* C, float* C_temp, const uint8_t* A,
const float* scales_a, const int32_t* qzeros_a,
const uint8_t* B, const float* scales_b,
const int8_t* qzeros_b, const int32_t* compensation,
int8_t* dqB_tmp, int64_t M, int64_t K, int64_t lda,
int64_t ldc_f, int64_t ldc_s, bool store_out,
bool use_brgemm) {
_dequant_gemm_accum<false, BLOCK_N, BLOCK_N / 2>(
C_temp, A, scales_a, qzeros_a, B, scales_b, qzeros_b, compensation,
dqB_tmp, M, K, lda, ldc_f, use_brgemm);
if (store_out) {
for (int64_t m = 0; m < M; ++m) {
copy_stub<scalar_t>(C + m * ldc_s, C_temp + m * ldc_f, BLOCK_N);
}
}
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a, \
const int32_t* qzeros_a, const uint8_t* B, const float* scales_b, \
const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s, \
bool store_out, bool use_brgemm)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
at::Tensor& w_scales,
std::optional<at::Tensor> bias) {
return int4_scaled_mm_cpu_with_quant(x, w, w_scales, w_zeros, bias,
x.scalar_type());
}

View File

@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni);
// Adapted from sglang: INT4 W4A8 kernels
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
at::Tensor& w_scales,
std::optional<at::Tensor> bias);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant);
// Adapted from sglang: INT4 W4A8 kernels
ops.def(
"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
"Tensor scales) -> (Tensor, Tensor, Tensor)");
ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
&convert_weight_packed_scale_zp);
ops.def(
"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
"Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
#endif
// CPU attention kernels

View File

@@ -3,8 +3,8 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <cassert>
#ifdef USE_ROCM

View File

@@ -6,14 +6,16 @@
#include <cstdio>
#include <cstdlib>
#include <torch/headeronly/util/shim_utils.h>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {

View File

@@ -1,7 +1,6 @@
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////

View File

@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -3,6 +3,14 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
// (stable ABI) targets. When compiled under the stable ABI target,
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
// use torch::stable::Tensor instead.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#endif
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
@@ -15,6 +23,12 @@
namespace vllm::c3x {
#ifdef TORCH_TARGET_VERSION
using TensorType = torch::stable::Tensor;
#else
using TensorType = torch::Tensor;
#endif
using namespace cute;
template <typename T>
@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
static auto args_from_tensor(TensorType const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
@@ -158,8 +172,8 @@ struct ScaledEpilogue
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& azp_adj,
std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& azp_adj,
TensorType const& azp,
std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -1,6 +1,21 @@
#pragma once
#include <torch/all.h>
// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
static inline auto make_cute_layout(TorchTensor const& tensor,
std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
return tensor.stride(idx);
}
});
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
template <typename Stride>
static inline auto maybe_make_cute_layout(
std::optional<torch::Tensor> const& tensor,
std::optional<TorchTensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
@@ -121,12 +136,12 @@ template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<c10::Half> {
struct equivalent_cutlass_type<torch::headeronly::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<c10::BFloat16> {
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
using type = cutlass::bfloat16_t;
};
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half;
using type = torch::headeronly::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16;
using type = torch::headeronly::BFloat16;
};
// get equivalent c10::ScalarType tag from compile time type
// get equivalent torch::headeronly::ScalarType tag from compile time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
static auto args_from_tensor(torch::stable::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
static auto args_from_tensor(
std::optional<torch::stable::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
@@ -117,8 +120,8 @@ struct ScaledEpilogue
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -49,6 +49,15 @@
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
// Half types dispatch (Half + BFloat16)
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
// Boolean dispatch
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \

View File

@@ -27,4 +27,111 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double int8_min,
double int8_max);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch);
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
// FP4/NVFP4 ops
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& output_block_scale,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
#endif

View File

@@ -2,10 +2,9 @@
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
@@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts(
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
@@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts(
namespace {
void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_group_scales.dtype() ==
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
TORCH_CHECK(out_tensors.dtype() ==
torch::kBFloat16); // only support bf16 for now
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::stable::Tensor& b_group_scales_ptrs,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
STD_TORCH_CHECK(a_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(
b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int); // int4 8x packed into int32
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(
b_group_scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
// type is e4m3
STD_TORCH_CHECK(
out_tensors.scalar_type() ==
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
int num_experts = static_cast<int>(expert_offsets.size(0));
// logical k, n
@@ -90,15 +100,16 @@ void run_get_group_gemm_starts(
int64_t k = a_tensors.size(1);
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace
} // namespace

View File

@@ -14,13 +14,12 @@
#include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "core/registration.h"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh"
@@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel {
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
"LayoutB_Reordered size must be divisible by 4 bytes");
static void grouped_mm(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides) {
static void grouped_mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes_torch,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides) {
auto device = a_tensors.device();
auto device_id = device.index();
const at::cuda::OptionalCUDAGuard device_guard(device);
auto stream = at::cuda::getCurrentCUDAStream(device_id);
const torch::stable::accelerator::DeviceGuard device_guard(device_id);
auto stream = get_current_cuda_stream(device_id);
int num_experts = static_cast<int>(expert_offsets.size(0));
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(device);
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor out_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
// get the correct offsets to pass to gemm
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
@@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel {
// Allocate workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
// Run GEMM
GemmShuffled gemm;
@@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop =
using Kernel_128x256_2x1x1_Coop =
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
void mm_dispatch(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides, const std::string& schedule) {
void mm_dispatch(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
const std::string& schedule) {
if (schedule == "Kernel_128x16_1x1x1_Coop") {
Kernel_128x16_1x1x1_Coop::grouped_mm(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
@@ -358,18 +372,23 @@ void mm_dispatch(
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, group_scale_strides);
} else {
TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
STD_TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
}
}
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides,
void mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
std::optional<std::string> maybe_schedule) {
// user has specified a schedule
if (maybe_schedule) {
@@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
a_strides, b_strides, c_strides, group_scale_strides, schedule);
}
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
torch::Tensor const& b_tensors) {
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
TORCH_CHECK(b_tensors.is_contiguous());
TORCH_CHECK(b_tensors.is_cuda());
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
STD_TORCH_CHECK(b_tensors.is_contiguous());
STD_TORCH_CHECK(b_tensors.is_cuda());
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
// These misalignments cause silent OOB unless run under Compute Sanitizer.
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
// we will store the layout to an int32 tensor;
// this is the number of elements we need per layout
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
int num_experts = static_cast<int>(b_tensors.size(0));
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
@@ -435,7 +455,7 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
size_t num_int4_elems = 1ull * num_experts * n * k;
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
num_int4_elems);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
// construct the layout once; assumes each expert has the same layout
using LayoutType = LayoutB_Reordered;
@@ -456,28 +476,28 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
}
// save the packed layout to torch tensor so we can re-use it
auto cpu_opts =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
torch::Tensor layout_cpu =
torch::empty({num_experts, layout_width}, cpu_opts);
torch::stable::Tensor layout_cpu = torch::stable::empty(
{num_experts, layout_width}, torch::headeronly::ScalarType::Int,
std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
for (int i = 0; i < num_experts; ++i) {
std::memcpy(layout_data + i * layout_width, // dst (int32*)
&layout_B_reordered, // src (LayoutType*)
sizeof(LayoutType)); // number of bytes
}
torch::Tensor packed_layout =
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
torch::stable::Tensor packed_layout =
torch::stable::to(layout_cpu, b_tensors.device(),
/*non_blocking=*/false);
return {b_tensors_packed, packed_layout};
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", &mm);
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
m.impl("cutlass_encode_and_reorder_int4b_grouped",
TORCH_BOX(&encode_and_reorder_int4b));
}
} // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -3,14 +3,12 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h"
#include "cutlass/cutlass.h"
#include <limits>
@@ -161,31 +159,31 @@ struct W4A8GemmKernel {
using StrideD = typename GemmKernelShuffled::StrideD;
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
static torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type) {
static torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
// TODO: param validation
int m = A.size(0);
int k = A.size(1);
int n = B.size(1);
// safely cast group_size to int
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
STD_TORCH_CHECK(
group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size);
// Allocate output
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
auto device = A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
torch::Tensor D =
torch::empty({m, n}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<ElementD>)
.device(device));
auto stream = get_current_cuda_stream(device.index());
torch::stable::Tensor D = torch::stable::empty(
{m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
// prepare arg pointers
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
@@ -237,9 +235,9 @@ struct W4A8GemmKernel {
// Workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
// Run GEMM
GemmShuffled gemm;
@@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
torch::Tensor mm_dispatch(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
const std::string& schedule) {
torch::stable::Tensor mm_dispatch(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
const std::string& schedule) {
if (schedule == "256x128_1x1x1") {
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
channel_scales, token_scales,
@@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A,
channel_scales, token_scales,
maybe_out_type);
}
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
return {};
}
torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size, torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
// requested a specific schedule
if (maybe_schedule) {
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
@@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A,
// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(scales.is_cuda());
torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
STD_TORCH_CHECK(scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(scales.is_contiguous());
STD_TORCH_CHECK(scales.is_cuda());
auto packed_scales = torch::empty(
{scales.numel() * ScalePackSize},
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
auto packed_scales =
torch::stable::empty({scales.numel() * ScalePackSize},
scales.scalar_type(), std::nullopt, scales.device());
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
auto packed_scales_ptr =
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
@@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return packed_scales;
}
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32);
TORCH_CHECK(B.dim() == 2);
torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(B.dim() == 2);
torch::Tensor B_packed = torch::empty_like(B);
torch::stable::Tensor B_packed = torch::stable::empty_like(B);
int k = B.size(0) * PackFactor; // logical k
int n = B.size(1);
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
STD_TORCH_CHECK((n * k) % 32 == 0,
"need multiples of 32 int4s for 16B chunks");
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
n * k);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
return B_packed;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_mm", &mm);
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
m.impl("cutlass_encode_and_reorder_int4b",
TORCH_BOX(&encode_and_reorder_int4b));
}
} // namespace vllm::cutlass_w4a8
} // namespace vllm::cutlass_w4a8

View File

@@ -14,16 +14,15 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
torch::Tensor& output_sf,
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& input_sf) {
void silu_and_mul_nvfp4_quant_sm1xxa(
torch::stable::Tensor& output, // [..., d]
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input, // [..., 2 * d]
torch::stable::Tensor& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1) / 2;
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
@@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());

View File

@@ -14,14 +14,12 @@
* limitations under the License.
*/
#include "core/registration.h"
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
@@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts(
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
@@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts(
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
const torch::stable::Tensor& b_starts,
const torch::stable::Tensor& out_starts,
const torch::stable::Tensor& a_scales_starts,
const torch::stable::Tensor& b_scales_starts,
const torch::stable::Tensor& alpha_starts,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& alphas,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& sf_offsets,
torch::stable::Tensor const& problem_sizes,
int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
STD_TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::Half, half,
LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num >= 120 && version_num < 130) {
@@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm(
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120");
@@ -575,26 +633,31 @@ void run_fp4_blockwise_scaled_group_mm(
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
#endif
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets) {
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
@@ -602,30 +665,34 @@ void cutlass_fp4_group_mm(
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
STD_TORCH_CHECK(
a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
STD_TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
STD_TORCH_CHECK(problem_sizes.dim() == 2,
"problem_sizes must be a 2D tensor");
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
STD_TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
@@ -633,7 +700,7 @@ void cutlass_fp4_group_mm(
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
int32_t version_num = get_sm_version_num();
if (version_num >= 120 && version_num < 130) {
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
output.scalar_type());
}
@@ -643,7 +710,7 @@ void cutlass_fp4_group_mm(
expert_offsets, sf_offsets, M, N, K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
@@ -651,6 +718,6 @@ void cutlass_fp4_group_mm(
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
}

View File

@@ -14,16 +14,15 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
@@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input,
} // namespace vllm
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
// Common validation for fp4 experts quantization entry points.
static void validate_fp4_experts_quant_inputs(
torch::Tensor const& output, torch::Tensor const& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
torch::stable::Tensor const& output,
torch::stable::Tensor const& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
int64_t k) {
CHECK_INPUT(output, "output");
CHECK_INPUT(output_scale, "output_scale");
@@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs(
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(output.dim() == 2);
STD_TORCH_CHECK(output_scale.dim() == 2);
STD_TORCH_CHECK(input.dim() == 2);
STD_TORCH_CHECK(input_global_scale.dim() == 1);
STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
STD_TORCH_CHECK(output.scalar_type() == UINT8);
STD_TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output.size(0) == m_topk);
STD_TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
}
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
auto k = input.size(1);
@@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
@@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa(
}
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
// Input has gate || up layout, so k = input.size(1) / 2
auto k_times_2 = input.size(1);
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
auto k = k_times_2 / 2;
validate_fp4_experts_quant_inputs(output, output_scale, input,
@@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(

View File

@@ -0,0 +1,175 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
#endif
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 quantization kernel");
}
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::stable::empty(
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
torch::stable::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::stable::empty(
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
} else {
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
torch::headeronly::ScalarType::Byte,
std::nullopt, device);
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -14,16 +14,16 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
@@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
@@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(
m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
}

View File

@@ -0,0 +1,87 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
const torch::stable::Tensor& A,
const torch::stable::Tensor& B,
const torch::stable::Tensor& A_sf,
const torch::stable::Tensor& B_sf,
const torch::stable::Tensor& alpha) {
// Make sure we're on A's device.
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
if (runtimeVersion < 12080) return false;
// Only report support when the SM-specific kernel was actually compiled in,
// so the Python-side backend selector does not choose CUTLASS and then hit
// TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin).
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (cuda_device_capability >= 100 && cuda_device_capability < 120)
return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (cuda_device_capability >= 120 && cuda_device_capability < 130)
return true;
#endif
return false;
}

View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
@@ -127,8 +126,9 @@ struct Fp4GemmSm100 {
template <typename Config>
typename Config::Gemm::Arguments args_from_options(
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
int64_t M, int64_t N, int64_t K) {
using ElementA = typename Config::Gemm::ElementA;
using ElementB = typename Config::Gemm::ElementB;
@@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options(
}
template <typename Config>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
cudaStream_t stream) {
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
typename Config::Gemm gemm;
auto arguments =
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
// Dispatch function to select appropriate config based on M
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 16) {
@@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
#else
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
// integer.
int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::Half) {
if (out_dtype == torch::headeronly::ScalarType::Half) {
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
k, stream);
} else if (out_dtype == at::ScalarType::BFloat16) {
} else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
m, n, k, stream);
} else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
")");
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
out_dtype, ")");
}
}

View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
@@ -34,19 +33,20 @@
using namespace cute;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>;
@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
};
template <typename Gemm>
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
torch::Tensor const& alpha, int M,
int N, int K) {
typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha,
int M, int N, int K) {
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD;
@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
}
template <typename Gemm>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
torch::Tensor const& alpha, int M, int N, int K,
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int M, int N, int K,
cudaStream_t stream) {
Gemm gemm;
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
}
}
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
}
}
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
// integer.
int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::BFloat16) {
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else if (out_dtype == at::ScalarType::Half) {
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
}
#else
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
}
}

View File

@@ -20,7 +20,7 @@
#include <cuda_fp8.h>
#include <utility>
#include "../../cuda_vec_utils.cuh"
#include "cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090

View File

@@ -2,9 +2,10 @@
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <ATen/cuda/CUDAContext.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass/cutlass.h"
@@ -25,14 +26,14 @@
namespace vllm::c3x {
static inline cute::Shape<int, int, int, int> get_problem_shape(
torch::Tensor const& a, torch::Tensor const& b) {
torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
return {m, n, k, 1};
}
template <typename GemmKernel>
void cutlass_gemm_caller(
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(device);
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto stream = get_current_cuda_stream(device.index());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC;

View File

@@ -4,13 +4,12 @@
namespace vllm {
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
if (azp) {
return cutlass_scaled_mm_sm90_int8_epilogue<
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,

View File

@@ -0,0 +1,22 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
}
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

View File

@@ -0,0 +1,22 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
}
template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
int M = a.size(0);
if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;

View File

@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
StrideA a_stride;
StrideB b_stride;
@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
}
template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
// TODO: better heuristics
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,

View File

@@ -1,52 +1,57 @@
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
void dispatch_scaled_mm(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
int32_t version_num = get_sm_version_num();
TORCH_CHECK(
STD_TORCH_CHECK(
false, "Int8 not supported on SM", version_num,
". Use FP8 quantization instead, or run on older arch (SM < 100).");
}
}
} else {
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
TORCH_CHECK(
STD_TORCH_CHECK(
a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
"a_scale_group_shape must be [1, 128].");
TORCH_CHECK(
STD_TORCH_CHECK(
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128].");
}
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}

View File

@@ -0,0 +1,52 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
} // namespace vllm

View File

@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB;
@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
template <typename InType, typename OutType, bool EnableBias,
typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
EpilogueArgs&&... args) {
inline void cutlass_gemm_sm100_fp8_dispatch(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType,
@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
}
template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, EnableBias>(
out, a, b, a_scales, b_scales,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, EnableBias>(
out, a, b, a_scales, b_scales,

View File

@@ -0,0 +1,25 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
int M = a.size(0);
@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);

View File

@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -235,8 +237,9 @@ struct sm90_fp8_config_M16_N8192 {
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB;
@@ -280,15 +283,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
template <typename InType, typename OutType, bool EnableBias,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
EpilogueArgs&&... args) {
inline void cutlass_gemm_sm90_fp8_dispatch(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<InType, OutType,
@@ -347,22 +350,24 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
}
template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, EnableBias>(
out, a, b, a_scales, b_scales,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, EnableBias>(
out, a, b, a_scales, b_scales,

View File

@@ -0,0 +1,25 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
@@ -87,13 +89,13 @@ struct sm90_int8_config_M32_NSmall {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass3xGemmDefault =
typename sm90_int8_config_default<InType, OutType,
@@ -142,19 +144,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_scaled_mm_sm90_int8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

View File

@@ -1,10 +1,10 @@
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "libtorch_stable/torch_utils.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int64_t*>(expert_offsets.data_ptr()), \
@@ -51,32 +51,39 @@ __global__ void get_group_gemm_starts(
namespace {
void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
STD_TORCH_CHECK(a_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace
} // namespace

View File

@@ -6,6 +6,7 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include <torch/csrc/stable/ops.h>
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
@@ -84,13 +85,17 @@ struct cutlass_3x_group_gemm {
};
template <typename Gemm>
void cutlass_group_gemm_caller(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB;
@@ -98,16 +103,20 @@ void cutlass_group_gemm_caller(
int num_experts = static_cast<int>(expert_offsets.size(0));
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
auto device = a_tensors.device();
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor out_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
@@ -156,7 +165,7 @@ void cutlass_group_gemm_caller(
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};
int device_id = a_tensors.device().index();
int device_id = a_tensors.get_device_index();
static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)};
@@ -170,9 +179,9 @@ void cutlass_group_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);

View File

@@ -1,7 +1,8 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
@@ -62,21 +63,27 @@ struct sm100_fp8_config_N8192 {
};
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
@@ -107,14 +114,18 @@ void run_cutlass_moe_mm_sm100(
}
} // namespace
void dispatch_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
@@ -127,13 +138,17 @@ void dispatch_moe_mm_sm100(
}
}
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);

View File

@@ -1,7 +1,8 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
@@ -103,21 +104,27 @@ struct sm90_fp8_config_N8192 {
};
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
@@ -163,14 +170,18 @@ void run_cutlass_moe_mm_sm90(
}
}
void dispatch_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
@@ -185,13 +196,17 @@ void dispatch_moe_mm_sm90(
} // namespace
void cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);

View File

@@ -1,9 +1,11 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include "dispatch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include <iostream>
@@ -110,19 +112,22 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
namespace {
inline void launch_compute_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
const bool swap_ab, const bool is_gated) {
inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& atomic_buffer,
int64_t num_experts, int64_t n,
int64_t k, cudaStream_t stream,
const bool swap_ab,
const bool is_gated) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
@@ -171,46 +176,53 @@ __global__ void compute_problem_sizes_from_expert_offsets(
}
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab) {
TORCH_CHECK(expert_first_token_offset.is_cuda(),
"expert_first_token_offset must be a CUDA tensor");
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
"expert_first_token_offset must be int64");
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab) {
STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
"expert_first_token_offset must be a CUDA tensor");
STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
torch::headeronly::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
"problem_sizes must be CUDA tensors");
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
problem_sizes2.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
"problem_sizes must be contiguous");
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
"problem_sizes must be 2D tensors");
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
"problem_sizes second dim must be 3");
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
"problem_sizes1 and problem_sizes2 must have same shape");
STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
"problem_sizes must be CUDA tensors");
STD_TORCH_CHECK(
problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32");
STD_TORCH_CHECK(
problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
"problem_sizes must be contiguous");
STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
"problem_sizes must be 2D tensors");
STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
"problem_sizes second dim must be 3");
STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
problem_sizes1.size(1) == problem_sizes2.size(1),
"problem_sizes1 and problem_sizes2 must have same shape");
int64_t const num_experts64 = problem_sizes1.size(0);
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
"expert_first_token_offset must have num_experts + 1 elements");
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
STD_TORCH_CHECK(
expert_first_token_offset.numel() == num_experts64 + 1,
"expert_first_token_offset must have num_experts + 1 elements");
STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
"n and k must fit in int32");
int const num_experts = static_cast<int>(num_experts64);
auto stream = at::cuda::getCurrentCUDAStream(
expert_first_token_offset.device().index());
auto stream =
get_current_cuda_stream(expert_first_token_offset.get_device_index());
int const threads = (num_experts < 256) ? num_experts : 256;
int const blocks = (num_experts + threads - 1) / threads;
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes_from_expert_offsets<SwapAB>
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
num_experts, static_cast<int>(n),
@@ -219,16 +231,19 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
}
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets,
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
auto device = topk_ids.device();
auto stream = get_current_cuda_stream(device.index());
torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
@@ -290,11 +305,13 @@ __global__ void compute_batched_moe_data(
}
void get_cutlass_batched_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
auto stream = get_current_cuda_stream(expert_offsets.get_device_index());
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
@@ -311,4 +328,4 @@ void get_cutlass_batched_moe_mm_data_caller(
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
}
}

View File

@@ -0,0 +1,220 @@
#include <stddef.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}

View File

@@ -1,8 +1,9 @@
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <ATen/cuda/CUDAContext.h>
#include "libtorch_stable/torch_utils.h"
// clang-format will break include orders
// clang-format off
@@ -95,8 +96,9 @@ struct cutlass_2x_gemm {
};
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@@ -149,11 +151,12 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto device = a.device();
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto stream = get_current_cuda_stream(device.index());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
@@ -161,9 +164,9 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
}
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
@@ -180,8 +183,8 @@ inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...);
} else {
TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
STD_TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
/**
@@ -70,13 +72,13 @@ struct sm75_config_M32 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm75_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass2xGemmDefault =
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
/**
@@ -72,13 +74,13 @@ struct sm80_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm80_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass2xGemmDefault =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
@@ -34,10 +36,12 @@ struct sm89_fp8_config_default {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -84,10 +88,12 @@ struct sm89_fp8_config_M256 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -125,10 +131,12 @@ struct sm89_fp8_config_M128 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -173,10 +181,12 @@ struct sm89_fp8_config_M64 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -227,10 +237,12 @@ struct sm89_fp8_config_M32 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -280,10 +292,12 @@ struct sm89_fp8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -326,13 +340,15 @@ struct sm89_fp8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm89_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
uint32_t const m = a.size(0);
uint32_t const mp2 =

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
/**
@@ -32,10 +34,11 @@ struct sm89_int8_config_default {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -88,10 +91,11 @@ struct sm89_int8_config_M256 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -143,10 +147,11 @@ struct sm89_int8_config_M128 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -193,10 +198,11 @@ struct sm89_int8_config_M64 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -234,10 +240,11 @@ struct sm89_int8_config_M32 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -276,10 +283,11 @@ struct sm89_int8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -311,13 +319,13 @@ struct sm89_int8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm89_int8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
uint32_t const m = a.size(0);
uint32_t const mp2 =

View File

@@ -8,11 +8,12 @@
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100

View File

@@ -8,11 +8,12 @@
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm120_fp8,
nullptr, // int8 not supported on SM120

View File

@@ -0,0 +1,38 @@
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
*/
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
azp, bias);
}
#endif

View File

@@ -0,0 +1,451 @@
#include <cudaTypedefs.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data_caller(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif
return false;
}
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif
return false;
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// or CUDA 12.8 and SM100 (Blackwell)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030;
}
#endif
return false;
}
void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if (version_num >= 120) {
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100 && version_num < 120) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 75) {
// Turing
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100 && version_num < 110) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
if (version_num >= 90 && version_num < 100) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100");
}
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
blockscale_offsets, is_gated);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
"no cutlass_scaled_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_batched_moe_mm_data: no "
"cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
}
if (azp) {
STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
}
STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types
STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!azp ||
azp->scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
"currently bias dtype must match output dtype ",
c.scalar_type());
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
// Turing
STD_TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}

View File

@@ -31,6 +31,174 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
// Check if cutlass grouped gemm is supported for CUDA devices of the given
// capability
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
// CUTLASS w8a8 grouped GEMM
ops.def(
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM. It takes topk_ids as an input, and computes expert_offsets
// (token start indices of each expert). In addition to this, it computes
// problem sizes for each expert's multiplication used by the two mms called
// from fused MoE operation, and arrays with permutations required to shuffle
// and de-shuffle the input/output of the fused operation.
ops.def(
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets, "
" bool is_gated) -> ()");
// compute per-expert problem sizes from expert_first_token_offset
// produced by vLLM's moe_permute kernel
ops.def(
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
" Tensor expert_first_token_offset, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int n, int k, bool swap_ab) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM in batched expert format. It takes expert_num_tokens
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()");
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
// Out variant
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
// This registration is now migrated to stable ABI
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
// yet in torch/headeronly), the tag should be applied from Python
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
// with the .impl remaining in C++.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
#endif
}
@@ -46,6 +214,45 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&per_token_group_quant_8bit_packed));
ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8));
// CUTLASS scaled_mm ops
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
// FP4/NVFP4 ops
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif
}
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
#ifndef USE_ROCM
ops.impl("cutlass_scaled_mm_supports_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
ops.impl("cutlass_group_gemm_supported",
TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif
}

View File

@@ -1,10 +1,17 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>
#include <cuda_runtime.h>
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)
// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {

View File

@@ -21,7 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
int64_t pad_slot_id;
int64_t null_block_id;
bool delta_softplus;
bool cache_enabled;

View File

@@ -118,9 +118,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
int cache_index;
if (cache_indices == nullptr) {
cache_index = batch_id;
} else if (params.cache_enabled) {
const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
} else {
cache_index = cache_indices[batch_id];
}
// Skip batch entries whose cache index maps to the null block (padding).
if (cache_indices != nullptr && cache_index == params.null_block_id){
return;
}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
@@ -527,7 +535,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool varlen,
int64_t pad_slot_id,
int64_t null_block_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -544,7 +552,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.dstate = dstate;
params.n_groups = n_groups;
params.dim_ngroups_ratio = dim / n_groups;
params.pad_slot_id = pad_slot_id;
params.null_block_id = null_block_id;
params.delta_softplus = delta_softplus;
@@ -658,7 +666,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id,
int64_t null_block_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -805,7 +813,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices,
has_initial_state,
varlen,
pad_slot_id,
null_block_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,

View File

@@ -1,144 +0,0 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
__nv_bfloat16* gC, __nv_bfloat16* bias,
int batch_size, int output_features,
int input_features, cudaStream_t stream) {
static int const WARP_TILE_M = 16;
static int const TILE_M = WARP_TILE_M;
static int const TILE_N = 8;
static int const TILE_K = 64;
static int const STAGES = 16;
static int const STAGE_UNROLL = 4;
static bool const PROFILE = false;
CUtensorMap weight_map{};
CUtensorMap activation_map{};
constexpr uint32_t rank = 2;
uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
uint32_t box_size[rank] = {TILE_K, TILE_M};
uint32_t elem_stride[rank] = {1, 1};
CUresult res = cuTensorMapEncodeTiled(
&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
gB, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for weight_map, error code=",
static_cast<int>(res));
size[1] = batch_size;
box_size[1] = TILE_N;
res = cuTensorMapEncodeTiled(
&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
rank, gA, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for activation_map, error code=",
static_cast<int>(res));
int smem_size = STAGES * STAGE_UNROLL *
(TILE_M * TILE_K * sizeof(__nv_bfloat16) +
TILE_N * TILE_K * sizeof(__nv_bfloat16));
gpuErrChk(cudaFuncSetAttribute(
gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
int tiles_m = (output_features + TILE_M - 1) / TILE_M;
int tiles_n = (batch_size + TILE_N - 1) / TILE_N;
dim3 grid(tiles_m, tiles_n);
dim3 block(384);
cudaLaunchConfig_t config;
cudaLaunchAttribute attrs[1];
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = smem_size;
config.stream = stream;
config.attrs = attrs;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
cudaLaunchKernelEx(
&config,
&gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
activation_map, nullptr);
}
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
torch::Tensor input, torch::Tensor weight,
torch::Tensor bias) {
auto const batch_size = input.size(0);
auto const input_dim = input.size(1);
auto const output_dim = weight.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::BFloat16) {
launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
(__nv_bfloat16*)weight.data_ptr(),
(__nv_bfloat16*)output.mutable_data_ptr(),
(__nv_bfloat16*)bias.data_ptr(), batch_size,
output_dim, input_dim, stream);
} else {
throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
}
}
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias) {
TORCH_CHECK(input.dim() == 2, "input must be 2D");
TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
"input.size(1) must match weight.size(1)");
TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
"weight.size(0) must match bias.size(0)");
TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
"input tensor must be bfloat16");
TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
"weight tensor must be bfloat16");
TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
"bias tensor must be bfloat16");
gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}

Some files were not shown because too many files have changed in this diff Show More